diff --git a/calculate_average_tonivade.sh b/calculate_average_tonivade.sh index 5e160f9..a484a53 100755 --- a/calculate_average_tonivade.sh +++ b/calculate_average_tonivade.sh @@ -15,5 +15,5 @@ # limitations under the License. # -JAVA_OPTS="-Xmx1G -Xms1G -XX:+AlwaysPreTouch --enable-preview" +JAVA_OPTS="-Xmx1G -Xms1G -XX:+AlwaysPreTouch -XX:+UseParallelGC -XX:-UseCompressedOops --enable-preview" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_tonivade diff --git a/prepare_tonivade.sh b/prepare_tonivade.sh index 66b23f6..cdf474f 100755 --- a/prepare_tonivade.sh +++ b/prepare_tonivade.sh @@ -17,4 +17,4 @@ # Uncomment below to use sdk source "$HOME/.sdkman/bin/sdkman-init.sh" -sdk use java 21.0.1-tem 1>&2 +sdk use java 21.0.2-tem 1>&2 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_tonivade.java b/src/main/java/dev/morling/onebrc/CalculateAverage_tonivade.java index bd28488..9deb3f2 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_tonivade.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_tonivade.java @@ -15,9 +15,6 @@ */ package dev.morling.onebrc; -import static java.util.Comparator.comparing; -import static java.util.stream.Collectors.joining; - import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; @@ -26,9 +23,8 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; import java.util.Map; +import java.util.TreeMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.StructuredTaskScope; import java.util.concurrent.StructuredTaskScope.Subtask; @@ -37,32 +33,16 @@ public class CalculateAverage_tonivade { private static final String FILE = "./measurements.txt"; - private static final int EOL = 10; - private static final int MINUS = 45; - private static final int SEMICOLON = 59; + private static final int MIN_CHUNK_SIZE = 1024; + private static final int MAX_NAME_LENGTH = 128; + private static final int MAX_TEMP_LENGTH = 8; public static void main(String[] args) throws IOException, InterruptedException, ExecutionException { - var result = readFile(); - - var measurements = getMeasurements(result); - - System.out.println(measurements); + System.out.println(readFile()); } - static record PartialResult(int end, Map map) { - - void merge(Map result) { - map.forEach((name, station) -> result.merge(name, station, Station::merge)); - } - } - - private static String getMeasurements(Map result) { - return result.values().stream().sorted(comparing(Station::getName)) - .map(Station::asString).collect(joining(", ", "{", "}")); - } - - private static Map readFile() throws IOException, InterruptedException, ExecutionException { - Map result = HashMap.newHashMap(10_000); + private static Map readFile() throws IOException, InterruptedException, ExecutionException { + Map result = new TreeMap<>(); try (var channel = FileChannel.open(Paths.get(FILE), StandardOpenOption.READ)) { long consumed = 0; long remaining = channel.size(); @@ -70,8 +50,11 @@ public class CalculateAverage_tonivade { var buffer = channel.map( MapMode.READ_ONLY, consumed, Math.min(remaining, Integer.MAX_VALUE)); - if (buffer.remaining() <= 1024) { - var partialResult = readChunk(buffer, 0, buffer.remaining()); + int chunks = Runtime.getRuntime().availableProcessors(); + int chunkSize = buffer.remaining() / chunks; + int leftover = buffer.remaining() % chunks; + if (chunkSize < MIN_CHUNK_SIZE) { + var partialResult = new Chunk(buffer, 0, buffer.remaining()).read(); consumed += partialResult.end(); remaining -= partialResult.end(); @@ -79,17 +62,12 @@ public class CalculateAverage_tonivade { partialResult.merge(result); } else { - var chunks = Runtime.getRuntime().availableProcessors(); - var chunksSize = buffer.remaining() / chunks; - var leftover = buffer.remaining() % chunks; - try (var scope = new StructuredTaskScope.ShutdownOnFailure()) { var tasks = new ArrayList>(chunks); for (int i = 0; i < chunks; i++) { - int start = i * chunksSize; - int length = chunksSize + (i < chunks ? leftover : 0); - tasks.add(scope.fork(() -> readChunk( - buffer, findStart(buffer, start), start + length))); + int start = i * chunkSize; + int length = chunkSize + (i < chunks ? leftover : 0); + tasks.add(scope.fork(new Chunk(buffer, start, length)::read)); } scope.join(); scope.throwIfFailed(); @@ -106,132 +84,154 @@ public class CalculateAverage_tonivade { return result; } - private static PartialResult readChunk(ByteBuffer buffer, int start, int end) { - final byte[] name = new byte[128]; - final byte[] temp = new byte[8]; - final Map map = HashMap.newHashMap(1000); - int position = start; - while (position < end) { - int semicolon = readName(buffer, position, end - position, name); - if (semicolon < 0) { - break; - } + static final class Chunk { - int endOfLine = readTemp(buffer, semicolon + 1, end - semicolon - 1, temp); - if (endOfLine < 0) { - break; - } + private static final int EOL = 10; + private static final int MINUS = 45; + private static final int SEMICOLON = 59; - map.computeIfAbsent(new Name(name, semicolon - position), Station::new) - .add(parseTemp(temp, endOfLine - semicolon - 1)); + final ByteBuffer buffer; + final int start; + final int end; - // skip end of line - position = endOfLine + 1; + final byte[] name = new byte[MAX_NAME_LENGTH]; + final byte[] temp = new byte[MAX_TEMP_LENGTH]; + final Stations stations = new Stations(); + + int hash; + + Chunk(ByteBuffer buffer, int start, int length) { + this.buffer = buffer; + this.start = findStart(buffer, start); + this.end = start + length; } - return new PartialResult(position, map); - } - private static int findStart(ByteBuffer buffer, int start) { - if (start > 0 && buffer.get(start - 1) != EOL) { - for (int i = start - 2; i > 0; i--) { - byte b = buffer.get(i); - if (b == EOL) { - return i + 1; + private static int findStart(ByteBuffer buffer, int start) { + if (start > 0 && buffer.get(start - 1) != EOL) { + for (int i = start - 2; i > 0; i--) { + byte b = buffer.get(i); + if (b == EOL) { + return i + 1; + } } } + return start; } - return start; - } - private static int readName(ByteBuffer buffer, int offset, int length, byte[] name) { - return readUntil(buffer, offset, length, name, SEMICOLON); - } + PartialResult read() { + int position = start; + while (position < end) { + int semicolon = readName(position, end - position); + if (semicolon < 0) { + break; + } - private static int readTemp(ByteBuffer buffer, int offset, int length, byte[] percentage) { - return readUntil(buffer, offset, length, percentage, EOL); - } + int endOfLine = readTemp(semicolon + 1, end - semicolon - 1); + if (endOfLine < 0) { + break; + } - private static int readUntil(ByteBuffer buffer, int offset, int length, byte[] array, int target) { - for (int i = 0; i < length; i++) { - byte b = buffer.get(i + offset); - if (b == target) { - return i + offset; + stations.find(name, semicolon - position, hash) + .add(parseTemp(temp, endOfLine - semicolon - 1)); + + // skip end of line + position = endOfLine + 1; } - array[i] = b; + return new PartialResult(position, stations.buckets); } - return -1; - } - // non null double between -99.9 (inclusive) and 99.9 (inclusive), always with one fractional digit - private static int parseTemp(byte[] value, int length) { - int period = length - 2; - if (value[0] == MINUS) { - int left = parseLeft(value, 1, period - 1); + private int readName(int offset, int length) { + hash = 1; + for (int i = 0; i < length; i++) { + byte b = buffer.get(i + offset); + if (b == SEMICOLON) { + return i + offset; + } + name[i] = b; + hash = 31 * hash + b; + } + return -1; + } + + private int readTemp(int offset, int length) { + for (int i = 0; i < length; i++) { + byte b = buffer.get(i + offset); + if (b == EOL) { + return i + offset; + } + temp[i] = b; + } + return -1; + } + + // non null double between -99.9 (inclusive) and 99.9 (inclusive), always with one fractional digit + private static int parseTemp(byte[] value, int length) { + int period = length - 2; + if (value[0] == MINUS) { + int left = parseLeft(value, 1, period - 1); + int right = toInt(value[period + 1]); + return -(left + right); + } + int left = parseLeft(value, 0, period); int right = toInt(value[period + 1]); - return -(left + right); - } - int left = parseLeft(value, 0, period); - int right = toInt(value[period + 1]); - return left + right; - } - - private static int parseLeft(byte[] value, int start, int length) { - if (length == 1) { - return toInt(value[start]) * 10; - } - // two chars - int a = toInt(value[start]) * 100; - int b = toInt(value[start + 1]) * 10; - return a + b; - } - - private static int toInt(byte c) { - return c - 48; - } - - static final class Name { - - private final byte[] value; - - Name(byte[] source, int length) { - value = new byte[length]; - System.arraycopy(source, 0, value, 0, length); + return left + right; } - @Override - public int hashCode() { - return Arrays.hashCode(value); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof Name other) { - return Arrays.equals(value, other.value); + private static int parseLeft(byte[] value, int start, int length) { + if (length == 1) { + return toInt(value[start]) * 10; } - return false; + // two chars + int a = toInt(value[start]) * 100; + int b = toInt(value[start + 1]) * 10; + return a + b; } - @Override - public String toString() { - return new String(value, StandardCharsets.UTF_8); + private static int toInt(byte c) { + return c - 48; + } + } + + static final class Stations { + + private static final int NUMBER_OF_BUCKETS = 1000; + private static final int BUCKET_SIZE = 50; + + final Station[][] buckets = new Station[NUMBER_OF_BUCKETS][BUCKET_SIZE]; + + Station find(byte[] name, int length, int hash) { + var bucket = buckets[Math.abs(hash % NUMBER_OF_BUCKETS)]; + for (int i = 0; i < BUCKET_SIZE; i++) { + if (bucket[i] == null) { + bucket[i] = new Station(name, length, hash); + return bucket[i]; + } + else if (bucket[i].sameName(length, hash)) { + return bucket[i]; + } + } + throw new IllegalStateException("no more space left"); } } static final class Station { - private final Name name; + private final byte[] name; + private final int hash; - private int min = Integer.MAX_VALUE; - private int max = Integer.MIN_VALUE; + private int min = 1000; + private int max = -1000; private int sum; private long count; - Station(Name name) { - this.name = name; + Station(byte[] source, int length, int hash) { + name = new byte[length]; + System.arraycopy(source, 0, name, 0, length); + this.hash = hash; } String getName() { - return name.toString(); + return new String(name, StandardCharsets.UTF_8); } void add(int value) { @@ -249,8 +249,13 @@ public class CalculateAverage_tonivade { return this; } - String asString() { - return name + "=" + toDouble(min) + "/" + round(mean()) + "/" + toDouble(max); + @Override + public String toString() { + return toDouble(min) + "/" + round(mean()) + "/" + toDouble(max); + } + + boolean sameName(int length, int hash) { + return name.length == length && this.hash == hash; } private double mean() { @@ -265,4 +270,17 @@ public class CalculateAverage_tonivade { return Math.round(value * 10.) / 10.; } } + + static record PartialResult(int end, Station[][] stations) { + + void merge(Map result) { + for (Station[] bucket : stations) { + for (Station station : bucket) { + if (station != null) { + result.merge(station.getName(), station, Station::merge); + } + } + } + } + } }