diff --git a/calculate_average_merykitty.sh b/calculate_average_merykitty.sh
index 1e944da..9183e59 100755
--- a/calculate_average_merykitty.sh
+++ b/calculate_average_merykitty.sh
@@ -16,5 +16,5 @@
 #
 
-JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector" # -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_merykitty::iterate"
+JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector -XX:-TieredCompilation" # -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_merykitty::iterate"
 java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_merykitty
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java b/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java
index e86ecee..1f5acf3 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java
@@ -25,8 +25,6 @@ import java.nio.channels.FileChannel.MapMode;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Path;
 import java.nio.file.StandardOpenOption;
-import java.util.Arrays;
-import java.util.HashMap;
 import java.util.Map;
 import java.util.TreeMap;
 import jdk.incubator.vector.ByteVector;
@@ -35,13 +33,21 @@ import jdk.incubator.vector.VectorSpecies;
 
 public class CalculateAverage_merykitty {
     private static final String FILE = "./measurements.txt";
-    private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED;
+    private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED.length() >= 32
+            ? ByteVector.SPECIES_256
+            : ByteVector.SPECIES_128;
     private static final ValueLayout.OfLong JAVA_LONG_LT = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
 
     private static final long KEY_MAX_SIZE = 100;
 
-    private record ResultRow(double min, double mean, double max) {
+    private static class Aggregator {
+        private int keySize;
+        private long min = Integer.MAX_VALUE;
+        private long max = Integer.MIN_VALUE;
+        private long sum;
+        private long count;
+
         public String toString() {
-            return round(min) + "/" + round(mean) + "/" + round(max);
+            return round(min / 10.) + "/" + round(sum / (double) (10 * count)) + "/" + round(max / 10.);
         }
 
         private double round(double value) {
@@ -49,96 +55,100 @@
         }
     }
 
-    private static class Aggregator {
-        private long min = Integer.MAX_VALUE;
-        private long max = Integer.MIN_VALUE;
-        private long sum;
-        private long count;
-    }
-
     // An open-address map that is specialized for this task
     private static class PoorManMap {
-        static final int R_LOAD_FACTOR = 2;
-        private static class PoorManMapNode {
-            byte[] data;
-            long size;
-            int hash;
-            Aggregator aggr;
+        // 100-byte key + 4-byte hash + 4-byte size +
+        // 2-byte min + 2-byte max + 8-byte sum + 8-byte count
+        private static final int KEY_SIZE = 128;
 
-            PoorManMapNode(MemorySegment data, long offset, long size, int hash) {
-                this.hash = hash;
-                this.size = size;
-                this.data = new byte[BYTE_SPECIES.vectorByteSize() + (int) KEY_MAX_SIZE];
-                this.aggr = new Aggregator();
-                MemorySegment.copy(data, offset, MemorySegment.ofArray(this.data), BYTE_SPECIES.vectorByteSize(), size);
+        // There is an assumption that map size <= 10000
+        private static final int CAPACITY = 1 << 17;
+        private static final int BUCKET_MASK = CAPACITY - 1;
+
+        byte[] keyData;
+        Aggregator[] nodes;
+
+        PoorManMap() {
+            this.keyData = new byte[CAPACITY * KEY_SIZE];
+            this.nodes = new Aggregator[CAPACITY];
+        }
+
+        void observe(Aggregator node, long value) {
+            node.min = Math.min(node.min, value);
+            node.max = Math.max(node.max, value);
+            node.sum += value;
+            node.count++;
+        }
+
+        Aggregator indexSimple(MemorySegment data, long offset, int size) {
+            int x;
+            int y;
+            if (size >= Integer.BYTES) {
+                x = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset);
+                y = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset + size - Integer.BYTES);
             }
-        }
-
-        MemorySegment data;
-        PoorManMapNode[] nodes;
-        int size;
-
-        PoorManMap(MemorySegment data) {
-            this.data = data;
-            this.nodes = new PoorManMapNode[1 << 10];
-        }
-
-        Aggregator indexSimple(long offset, long size, int hash) {
-            hash = rehash(hash);
-            int bucketMask = nodes.length - 1;
-            int bucket = hash & bucketMask;
-            for (;; bucket = (bucket + 1) & bucketMask) {
-                PoorManMapNode node = nodes[bucket];
+            else {
+                x = data.get(ValueLayout.JAVA_BYTE, offset);
+                y = data.get(ValueLayout.JAVA_BYTE, offset + size - Byte.BYTES);
+            }
+            int hash = hash(x, y);
+            int bucket = hash & BUCKET_MASK;
+            for (;; bucket = (bucket + 1) & BUCKET_MASK) {
+                var node = this.nodes[bucket];
                 if (node == null) {
-                    this.size++;
-                    if (this.size * R_LOAD_FACTOR > nodes.length) {
-                        grow();
-                        bucketMask = nodes.length - 1;
-                        for (bucket = hash & bucketMask; nodes[bucket] != null; bucket = (bucket + 1) & bucketMask) {
-                        }
-                    }
-                    node = new PoorManMapNode(this.data, offset, size, hash);
-                    nodes[bucket] = node;
-                    return node.aggr;
+                    return insertInto(bucket, data, offset, size);
                 }
-                else if (keyEqualScalar(node, offset, size, hash)) {
-                    return node.aggr;
+                else if (keyEqualScalar(bucket, data, offset, size)) {
+                    return node;
                 }
             }
        }
 
-        void grow() {
-            var oldNodes = this.nodes;
-            var newNodes = new PoorManMapNode[oldNodes.length * 2];
-            int bucketMask = newNodes.length - 1;
-            for (var node : oldNodes) {
+        Aggregator insertInto(int bucket, MemorySegment data, long offset, int size) {
+            var node = new Aggregator();
+            node.keySize = size;
+            this.nodes[bucket] = node;
+            MemorySegment.copy(data, offset, MemorySegment.ofArray(this.keyData), (long) bucket * KEY_SIZE, size);
+            return node;
+        }
+
+        void mergeInto(Map<String, Aggregator> target) {
+            for (int i = 0; i < CAPACITY; i++) {
+                var node = this.nodes[i];
                 if (node == null) {
                     continue;
                 }
-                int bucket = node.hash & bucketMask;
-                for (; newNodes[bucket] != null; bucket = (bucket + 1) & bucketMask) {
-                }
-                newNodes[bucket] = node;
+
+                String key = new String(this.keyData, i * KEY_SIZE, node.keySize, StandardCharsets.UTF_8);
+                target.compute(key, (k, v) -> {
+                    if (v == null) {
+                        v = new Aggregator();
+                    }
+
+                    v.min = Math.min(v.min, node.min);
+                    v.max = Math.max(v.max, node.max);
+                    v.sum += node.sum;
+                    v.count += node.count;
+                    return v;
+                });
             }
-            this.nodes = newNodes;
         }
 
-        static int rehash(int x) {
-            x = ((x >>> 16) ^ x) * 0x45d9f3b;
-            x = ((x >>> 16) ^ x) * 0x45d9f3b;
-            x = (x >>> 16) ^ x;
-            return x;
+        static int hash(int x, int y) {
+            int seed = 0x9E3779B9;
+            int rotate = 5;
+            return (Integer.rotateLeft(x * seed, rotate) ^ y) * seed; // FxHash
         }
 
-        private boolean keyEqualScalar(PoorManMapNode node, long offset, long size, int hash) {
-            if (node.hash != hash || node.size != size) {
+        private boolean keyEqualScalar(int bucket, MemorySegment data, long offset, int size) {
+            if (this.nodes[bucket].keySize != size) {
                 return false;
             }
 
             // Be simple
             for (int i = 0; i < size; i++) {
-                int c1 = node.data[BYTE_SPECIES.vectorByteSize() + i];
+                int c1 = this.keyData[bucket * KEY_SIZE + i];
                 int c2 = data.get(ValueLayout.JAVA_BYTE, offset + i);
                 if (c1 != c2) {
                     return false;
@@ -152,7 +162,7 @@
     // 1 - 2 digits to the left and 1 digits to the right of the separator to a
     // fix-precision format. It returns the offset of the next line (presumably followed
     // the final digit and a '\n')
-    private static long parseDataPoint(Aggregator aggr, MemorySegment data, long offset) {
+    private static long parseDataPoint(PoorManMap aggrMap, Aggregator node, MemorySegment data, long offset) {
         long word = data.get(JAVA_LONG_LT, offset);
         // The 4th binary digit of the ascii of a digit is 1 while
         // that of the '.' is 0. This finds the decimal separator
@@ -176,16 +186,13 @@
         // That was close :)
         long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
         long value = (absValue ^ signed) - signed;
-        aggr.min = Math.min(value, aggr.min);
-        aggr.max = Math.max(value, aggr.max);
-        aggr.sum += value;
-        aggr.count++;
+        aggrMap.observe(node, value);
         return offset + (decimalSepPos >>> 3) + 3;
     }
 
     // Tail processing version of the above, do not over-fetch and be simple
-    private static long parseDataPointTail(Aggregator aggr, MemorySegment data, long offset) {
-        int point = 0;
+    private static long parseDataPointSimple(PoorManMap aggrMap, Aggregator node, MemorySegment data, long offset) {
+        int value = 0;
         boolean negative = false;
         if (data.get(ValueLayout.JAVA_BYTE, offset) == '-') {
             negative = true;
@@ -195,110 +202,80 @@
             int c = data.get(ValueLayout.JAVA_BYTE, offset);
             if (c == '.') {
                 c = data.get(ValueLayout.JAVA_BYTE, offset + 1);
-                point = point * 10 + (c - '0');
+                value = value * 10 + (c - '0');
                 offset += 3;
                 break;
             }
-            point = point * 10 + (c - '0');
+            value = value * 10 + (c - '0');
         }
 
-        point = negative ? -point : point;
-        aggr.min = Math.min(point, aggr.min);
-        aggr.max = Math.max(point, aggr.max);
-        aggr.sum += point;
-        aggr.count++;
+        value = negative ? -value : value;
+        aggrMap.observe(node, value);
         return offset;
     }
 
-    // An iteration of the main parse loop, parse some lines starting from offset.
-    // This requires offset to be the start of a line and there is spare space so
+    // An iteration of the main parse loop, parse a line starting from offset.
+    // This requires offset to be the start of the line and that there is spare space so
    // that we have relative freedom in processing
-    // It returns the offset of the next line that it needs to be processed
+    // It returns the offset of the next line that needs processing
     private static long iterate(PoorManMap aggrMap, MemorySegment data, long offset) {
-        // This method fetches a segment of the file starting from offset and returns after
-        // finishing processing that segment
         var line = ByteVector.fromMemorySegment(BYTE_SPECIES, data, offset, ByteOrder.nativeOrder());
 
         // Find the delimiter ';'
-        long semicolons = line.compare(VectorOperators.EQ, ';').toLong();
+        int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();
 
-        // If we cannot find the delimiter in the current segment, that means the key is
-        // longer than the segment, fall back to scalar processing
-        if (semicolons == 0) {
-            long semicolonPos = BYTE_SPECIES.vectorByteSize();
-            for (; data.get(ValueLayout.JAVA_BYTE, offset + semicolonPos) != ';'; semicolonPos++) {
+        // If we cannot find the delimiter in the vector, that means the key is
+        // longer than the vector, fall back to scalar processing
+        if (keySize == BYTE_SPECIES.vectorByteSize()) {
+            while (data.get(ValueLayout.JAVA_BYTE, offset + keySize) != ';') {
+                keySize++;
             }
-            int hash = line.reinterpretAsInts().lane(0);
-            var aggr = aggrMap.indexSimple(offset, semicolonPos, hash);
-            return parseDataPoint(aggr, data, offset + 1 + semicolonPos);
+            var node = aggrMap.indexSimple(data, offset, keySize);
+            return parseDataPoint(aggrMap, node, data, offset + 1 + keySize);
         }
 
-        long currOffset = offset;
-        while (true) {
-            // Process line by line, currOffset is the offset of the current line in
-            // the file, localOffset is the offset of the current line with respect
-            // to the start of the iteration segment
-            int localOffset = (int) (currOffset - offset);
-
-            // The key length
-            long semicolonPos = Long.numberOfTrailingZeros(semicolons) - localOffset;
-            int hash = data.get(ValueLayout.JAVA_INT_UNALIGNED, currOffset);
-            if (semicolonPos < Integer.BYTES) {
-                hash = (byte) hash;
-            }
-
-            // We inline the searching of the value in the hash map
-            Aggregator aggr;
-            hash = PoorManMap.rehash(hash);
-            int bucketMask = aggrMap.nodes.length - 1;
-            int bucket = hash & bucketMask;
-            for (;; bucket = (bucket + 1) & bucketMask) {
-                PoorManMap.PoorManMapNode node = aggrMap.nodes[bucket];
-                if (node == null) {
-                    aggrMap.size++;
-                    if (aggrMap.size * PoorManMap.R_LOAD_FACTOR > aggrMap.nodes.length) {
-                        aggrMap.grow();
-                        bucketMask = aggrMap.nodes.length - 1;
-                        for (bucket = hash & bucketMask; aggrMap.nodes[bucket] != null; bucket = (bucket + 1) & bucketMask) {
-                        }
-                    }
-                    node = new PoorManMap.PoorManMapNode(data, currOffset, semicolonPos, hash);
-                    aggrMap.nodes[bucket] = node;
-                    aggr = node.aggr;
-                    break;
-                }
-
-                if (node.hash != hash || node.size != semicolonPos) {
-                    continue;
-                }
-
-                // The technique here is to align the key in both vectors so that we can do an
-                // element-wise comparison and check if all characters match
-                var nodeKey = ByteVector.fromArray(BYTE_SPECIES, node.data, BYTE_SPECIES.length() - localOffset);
-                var eqMask = line.compare(VectorOperators.EQ, nodeKey).toLong();
-                long validMask = (-1L >>> -semicolonPos) << localOffset;
-                if ((eqMask & validMask) == validMask) {
-                    aggr = node.aggr;
-                    break;
-                }
-            }
-
-            long nextOffset = parseDataPoint(aggr, data, currOffset + 1 + semicolonPos);
-            semicolons &= (semicolons - 1);
-            if (semicolons == 0) {
-                return nextOffset;
-            }
-            currOffset = nextOffset;
+        // We inline the searching of the value in the hash map
+        int x;
+        int y;
+        if (keySize >= Integer.BYTES) {
+            x = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset);
+            y = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset + keySize - Integer.BYTES);
         }
+        else {
+            x = data.get(ValueLayout.JAVA_BYTE, offset);
+            y = data.get(ValueLayout.JAVA_BYTE, offset + keySize - Byte.BYTES);
+        }
+        int hash = PoorManMap.hash(x, y);
+        int bucket = hash & PoorManMap.BUCKET_MASK;
+        Aggregator node;
+        for (;; bucket = (bucket + 1) & PoorManMap.BUCKET_MASK) {
+            node = aggrMap.nodes[bucket];
+            if (node == null) {
+                node = aggrMap.insertInto(bucket, data, offset, keySize);
+                break;
+            }
+            if (node.keySize != keySize) {
+                continue;
+            }
+
+            var nodeKey = ByteVector.fromArray(BYTE_SPECIES, aggrMap.keyData, bucket * PoorManMap.KEY_SIZE);
+            long eqMask = line.compare(VectorOperators.EQ, nodeKey).toLong();
+            long validMask = -1L >>> -keySize;
+            if ((eqMask & validMask) == validMask) {
+                break;
+            }
+        }
+
+        return parseDataPoint(aggrMap, node, data, offset + keySize + 1);
     }
 
     // Process all lines that start in [offset, limit)
     private static PoorManMap processFile(MemorySegment data, long offset, long limit) {
-        var aggrMap = new PoorManMap(data);
+        var aggrMap = new PoorManMap();
 
         // Find the start of a new line
         if (offset != 0) {
             offset--;
-            for (; offset < limit;) {
+            while (offset < limit) {
                 if (data.get(ValueLayout.JAVA_BYTE, offset++) == '\n') {
                     break;
                 }
@@ -318,18 +295,12 @@
 
         // Now we are at the tail, just be simple
         while (offset < limit) {
-            long semicolonPos = 0;
-            for (; data.get(ValueLayout.JAVA_BYTE, offset + semicolonPos) != ';'; semicolonPos++) {
+            int keySize = 0;
+            while (data.get(ValueLayout.JAVA_BYTE, offset + keySize) != ';') {
+                keySize++;
             }
-            int hash;
-            if (semicolonPos >= Integer.BYTES) {
-                hash = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset);
-            }
-            else {
-                hash = data.get(ValueLayout.JAVA_BYTE, offset);
-            }
-            var aggr = aggrMap.indexSimple(offset, semicolonPos, hash);
-            offset = parseDataPointTail(aggr, data, offset + 1 + semicolonPos);
+            var node = aggrMap.indexSimple(data, offset, keySize);
+            offset = parseDataPointSimple(aggrMap, node, data, offset + 1 + keySize);
         }
 
         return aggrMap;
@@ -337,7 +308,7 @@
     public static void main(String[] args) throws InterruptedException, IOException {
         int processorCnt = Runtime.getRuntime().availableProcessors();
-        var res = HashMap.<String, Aggregator> newHashMap(processorCnt);
+        var res = new TreeMap<String, Aggregator>();
 
         try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
                 var arena = Arena.ofShared()) {
             var data = file.map(MapMode.READ_ONLY, 0, file.size(), arena);
@@ -348,9 +319,7 @@
                 int index = i;
                 long offset = i * chunkSize;
                 long limit = Math.min((i + 1) * chunkSize, data.byteSize());
-                var thread = new Thread(() -> {
-                    resultList[index] = processFile(data, offset, limit);
-                });
+                var thread = new Thread(() -> resultList[index] = processFile(data, offset, limit));
                 threadList[index] = thread;
                 thread.start();
             }
@@ -360,32 +329,10 @@
 
             // Collect the results
             for (var aggrMap : resultList) {
-                for (var node : aggrMap.nodes) {
-                    if (node == null) {
-                        continue;
-                    }
-                    byte[] keyData = Arrays.copyOfRange(node.data, BYTE_SPECIES.vectorByteSize(), BYTE_SPECIES.vectorByteSize() + (int) node.size);
-                    String key = new String(keyData, StandardCharsets.UTF_8);
-                    var aggr = node.aggr;
-                    var resAggr = new Aggregator();
-                    var existingAggr = res.putIfAbsent(key, resAggr);
-                    if (existingAggr != null) {
-                        resAggr = existingAggr;
-                    }
-                    resAggr.min = Math.min(resAggr.min, aggr.min);
-                    resAggr.max = Math.max(resAggr.max, aggr.max);
-                    resAggr.sum += aggr.sum;
-                    resAggr.count += aggr.count;
-                }
+                aggrMap.mergeInto(res);
             }
         }
 
-        Map<String, ResultRow> measurements = new TreeMap<>();
-        for (var entry : res.entrySet()) {
-            String key = entry.getKey();
-            var aggr = entry.getValue();
-            measurements.put(key, new ResultRow((double) aggr.min / 10, (double) aggr.sum / (aggr.count * 10), (double) aggr.max / 10));
-        }
-        System.out.println(measurements);
+        System.out.println(res);
     }
 }
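
Note on the new hash: PoorManMap.hash is an FxHash-style mix over only the first and last four bytes of the station name. Collisions are cheap to tolerate because every probe still ends in a full key comparison (keyEqualScalar, or the vector compare in iterate). Below is a minimal standalone sketch of the same bucket computation; the class FxHashSketch and the helpers readInt/bucketOf are illustrative names, not part of the patch, and readInt stands in for the little-endian JAVA_INT_UNALIGNED reads.

import java.nio.charset.StandardCharsets;

public class FxHashSketch {
    static final int CAPACITY = 1 << 17;      // power of two, as in the patch
    static final int BUCKET_MASK = CAPACITY - 1;

    // FxHash-style mix of two ints (the key's first and last four bytes)
    static int hash(int x, int y) {
        int seed = 0x9E3779B9;                // golden-ratio constant
        return (Integer.rotateLeft(x * seed, 5) ^ y) * seed;
    }

    // Little-endian unaligned 4-byte read, standing in for JAVA_INT_UNALIGNED
    static int readInt(byte[] b, int off) {
        return (b[off] & 0xFF)
                | (b[off + 1] & 0xFF) << 8
                | (b[off + 2] & 0xFF) << 16
                | (b[off + 3] & 0xFF) << 24;
    }

    static int bucketOf(byte[] key) {
        int x;
        int y;
        if (key.length >= Integer.BYTES) {    // head and tail reads may overlap
            x = readInt(key, 0);
            y = readInt(key, key.length - Integer.BYTES);
        }
        else {                                // 1-3 byte keys: single bytes
            x = key[0];
            y = key[key.length - 1];
        }
        return hash(x, y) & BUCKET_MASK;      // linear probing starts here
    }

    public static void main(String[] args) {
        System.out.println(bucketOf("Hamburg".getBytes(StandardCharsets.UTF_8)));
        System.out.println(bucketOf("Hà Nội".getBytes(StandardCharsets.UTF_8)));
    }
}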
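
Note on parseDataPoint: the patch keeps the branchless SWAR parse and only swaps the accumulation for aggrMap.observe(node, value). The trick is worth restating: bit 4 of every ASCII digit is 1 while bit 4 of '.' (and of '-') is 0, so ~word & 0x10101000 locates the decimal separator, and one multiply by 0x640a0001 gathers hundreds*100 + tens*10 + units into bits 32-41. A self-contained sketch of the same arithmetic, assuming the challenge's "a.b"/"ab.c" input shapes; the class name SwarParseSketch and the pack helper are mine, while the real code loads the word directly from the mapped file with JAVA_LONG_LT.

import java.nio.charset.StandardCharsets;

public class SwarParseSketch {
    // Parses one measurement ("-99.9" .. "99.9") packed little-endian into a
    // long, returning the value scaled by 10 (e.g. "-12.3" -> -123).
    static long parse(long word) {
        // Bit 4 of ASCII '0'..'9' is 1, bit 4 of '.' is 0: find the separator.
        int decimalSepPos = Long.numberOfTrailingZeros(~word & 0x10101000L);
        // Bit 4 of '-' is also 0: expand it into an all-ones mask if present.
        long signed = (~word << 59) >> 63;
        long designMask = ~(signed & 0xFF);            // erase the '-' byte
        // Align so the hundreds/tens/units digits land at bits 8/16/32.
        long digits = ((word & designMask) << (28 - decimalSepPos)) & 0x0F000F0F00L;
        // One multiply gathers hundreds*100 + tens*10 + units at bits 32-41.
        long absValue = ((digits * 0x640a0001L) >>> 32) & 0x3FF;
        return (absValue ^ signed) - signed;           // apply the sign
    }

    // Little-endian packing helper, standing in for the JAVA_LONG_LT load.
    static long pack(String s) {
        byte[] b = s.getBytes(StandardCharsets.US_ASCII);
        long word = 0;
        for (int i = 0; i < b.length && i < Long.BYTES; i++) {
            word |= (b[i] & 0xFFL) << (8 * i);
        }
        return word;
    }

    public static void main(String[] args) {
        System.out.println(parse(pack("12.3\n")));  // 123
        System.out.println(parse(pack("-5.9\n")));  // -59
        System.out.println(parse(pack("-99.9\n"))); // -999
    }
}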
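
Note on the delimiter search in iterate: replacing toLong() plus bit scanning with firstTrue() turns the key length into a single mask query, and the "not found" sentinel (firstTrue() returns the species length when no lane matches) is exactly the condition for the scalar long-key fallback. A runnable sketch, using SPECIES_PREFERRED rather than the patch's capped 256/128-bit choice; compile and run with --add-modules jdk.incubator.vector.

import java.nio.charset.StandardCharsets;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

public class FindDelimiterSketch {
    static final VectorSpecies<Byte> SPECIES = ByteVector.SPECIES_PREFERRED;

    // Index of the first ';' in the window, or SPECIES.length() if absent.
    static int keySize(byte[] window) {
        var line = ByteVector.fromArray(SPECIES, window, 0);
        return line.compare(VectorOperators.EQ, (byte) ';').firstTrue();
    }

    public static void main(String[] args) {
        byte[] window = new byte[SPECIES.length()];   // stands in for the mapped file
        byte[] src = "Hamburg;12.3\n".getBytes(StandardCharsets.UTF_8);
        System.arraycopy(src, 0, window, 0, Math.min(src.length, window.length));
        System.out.println(keySize(window));          // 7 == "Hamburg".length()
    }
}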
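
Note on the chunking in processFile/main: the file is split into one chunk per core and each worker owns every line that starts in [offset, limit). The re-alignment loop steps back one byte before scanning for '\n', so an offset that already sits at a line start keeps it (its previous byte is the prior line's '\n'), and no line is claimed by two chunks. A small sketch of just that alignment over a byte array instead of the mapped segment; names are illustrative.

import java.nio.charset.StandardCharsets;

public class ChunkAlignSketch {
    // Move a raw chunk offset to the start of the next line, mirroring the
    // alignment loop at the top of processFile in the patch.
    static long alignToLine(byte[] data, long offset, long limit) {
        if (offset == 0) {
            return 0;                      // chunk 0 already starts a line
        }
        offset--;                          // re-check the byte just before us
        while (offset < limit) {
            if (data[(int) offset++] == '\n') {
                break;                     // offset now starts the next line
            }
        }
        return offset;
    }

    public static void main(String[] args) {
        byte[] data = "Oslo;3.1\nPerth;22.0\n".getBytes(StandardCharsets.UTF_8);
        System.out.println(alignToLine(data, 6, data.length)); // 9: mid-line split skips ahead
        System.out.println(alignToLine(data, 9, data.length)); // 9: a line start is kept
    }
}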