diff --git a/calculate_average_merykitty.sh b/calculate_average_merykitty.sh new file mode 100755 index 0000000..7cee051 --- /dev/null +++ b/calculate_average_merykitty.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector" # -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_merykitty::iterate" +time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_merykitty diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java b/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java new file mode 100644 index 0000000..e86ecee --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java @@ -0,0 +1,391 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorSpecies; + +public class CalculateAverage_merykitty { + private static final String FILE = "./measurements.txt"; + private static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_PREFERRED; + private static final ValueLayout.OfLong JAVA_LONG_LT = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + private static final long KEY_MAX_SIZE = 100; + + private record ResultRow(double min, double mean, double max) { + public String toString() { + return round(min) + "/" + round(mean) + "/" + round(max); + } + + private double round(double value) { + return Math.round(value * 10.0) / 10.0; + } + } + + private static class Aggregator { + private long min = Integer.MAX_VALUE; + private long max = Integer.MIN_VALUE; + private long sum; + private long count; + } + + // An open-address map that is specialized for this task + private static class PoorManMap { + static final int R_LOAD_FACTOR = 2; + + private static class PoorManMapNode { + byte[] data; + long size; + int hash; + Aggregator aggr; + + PoorManMapNode(MemorySegment data, long offset, long size, int hash) { + this.hash = hash; + this.size = size; + this.data = new byte[BYTE_SPECIES.vectorByteSize() + (int) KEY_MAX_SIZE]; + this.aggr = new Aggregator(); + MemorySegment.copy(data, offset, MemorySegment.ofArray(this.data), BYTE_SPECIES.vectorByteSize(), size); + } + } + + MemorySegment data; + PoorManMapNode[] nodes; + int size; + + PoorManMap(MemorySegment data) { + this.data = data; + this.nodes = new PoorManMapNode[1 << 10]; + } + + Aggregator indexSimple(long offset, long size, int hash) { + hash = rehash(hash); + int bucketMask = nodes.length - 1; + int bucket = hash & bucketMask; + for (;; bucket = (bucket + 1) & bucketMask) { + PoorManMapNode node = nodes[bucket]; + if (node == null) { + this.size++; + if (this.size * R_LOAD_FACTOR > nodes.length) { + grow(); + bucketMask = nodes.length - 1; + for (bucket = hash & bucketMask; nodes[bucket] != null; bucket = (bucket + 1) & bucketMask) { + } + } + node = new PoorManMapNode(this.data, offset, size, hash); + nodes[bucket] = node; + return node.aggr; + } + else if (keyEqualScalar(node, offset, size, hash)) { + return node.aggr; + } + } + } + + void grow() { + var oldNodes = this.nodes; + var newNodes = new PoorManMapNode[oldNodes.length * 2]; + int bucketMask = newNodes.length - 1; + for (var node : oldNodes) { + if (node == null) { + continue; + } + int bucket = node.hash & bucketMask; + for (; newNodes[bucket] != null; bucket = (bucket + 1) & bucketMask) { + } + newNodes[bucket] = node; + } + this.nodes = newNodes; + } + + static int rehash(int x) { + x = ((x >>> 16) ^ x) * 0x45d9f3b; + x = ((x >>> 16) ^ x) * 0x45d9f3b; + x = (x >>> 16) ^ x; + return x; + } + + private boolean keyEqualScalar(PoorManMapNode node, long offset, long size, int hash) { + if (node.hash != hash || node.size != size) { + return false; + } + + // Be simple + for (int i = 0; i < size; i++) { + int c1 = node.data[BYTE_SPECIES.vectorByteSize() + i]; + int c2 = data.get(ValueLayout.JAVA_BYTE, offset + i); + if (c1 != c2) { + return false; + } + } + return true; + } + } + + // Parse a number that may/may not contain a minus sign followed by a decimal with + // 1 - 2 digits to the left and 1 digits to the right of the separator to a + // fix-precision format. It returns the offset of the next line (presumably followed + // the final digit and a '\n') + private static long parseDataPoint(Aggregator aggr, MemorySegment data, long offset) { + long word = data.get(JAVA_LONG_LT, offset); + // The 4th binary digit of the ascii of a digit is 1 while + // that of the '.' is 0. This finds the decimal separator + // The value can be 12, 20, 28 + int decimalSepPos = Long.numberOfTrailingZeros(~word & 0x10101000); + int shift = 28 - decimalSepPos; + // signed is -1 if negative, 0 otherwise + long signed = (~word << 59) >> 63; + long designMask = ~(signed & 0xFF); + // Align the number to a specific position and transform the ascii code + // to actual digit value in each byte + long digits = ((word & designMask) << shift) & 0x0F000F0F00L; + + // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit) + // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) = + // 0x000000UU00TTHH00 + + // 0x00UU00TTHH000000 * 10 + + // 0xUU00TTHH00000000 * 100 + // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400 + // This results in our value lies in the bit 32 to 41 of this product + // That was close :) + long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; + long value = (absValue ^ signed) - signed; + aggr.min = Math.min(value, aggr.min); + aggr.max = Math.max(value, aggr.max); + aggr.sum += value; + aggr.count++; + return offset + (decimalSepPos >>> 3) + 3; + } + + // Tail processing version of the above, do not over-fetch and be simple + private static long parseDataPointTail(Aggregator aggr, MemorySegment data, long offset) { + int point = 0; + boolean negative = false; + if (data.get(ValueLayout.JAVA_BYTE, offset) == '-') { + negative = true; + offset++; + } + for (;; offset++) { + int c = data.get(ValueLayout.JAVA_BYTE, offset); + if (c == '.') { + c = data.get(ValueLayout.JAVA_BYTE, offset + 1); + point = point * 10 + (c - '0'); + offset += 3; + break; + } + + point = point * 10 + (c - '0'); + } + point = negative ? -point : point; + aggr.min = Math.min(point, aggr.min); + aggr.max = Math.max(point, aggr.max); + aggr.sum += point; + aggr.count++; + return offset; + } + + // An iteration of the main parse loop, parse some lines starting from offset. + // This requires offset to be the start of a line and there is spare space so + // that we have relative freedom in processing + // It returns the offset of the next line that it needs to be processed + private static long iterate(PoorManMap aggrMap, MemorySegment data, long offset) { + // This method fetches a segment of the file starting from offset and returns after + // finishing processing that segment + var line = ByteVector.fromMemorySegment(BYTE_SPECIES, data, offset, ByteOrder.nativeOrder()); + + // Find the delimiter ';' + long semicolons = line.compare(VectorOperators.EQ, ';').toLong(); + + // If we cannot find the delimiter in the current segment, that means the key is + // longer than the segment, fall back to scalar processing + if (semicolons == 0) { + long semicolonPos = BYTE_SPECIES.vectorByteSize(); + for (; data.get(ValueLayout.JAVA_BYTE, offset + semicolonPos) != ';'; semicolonPos++) { + } + int hash = line.reinterpretAsInts().lane(0); + var aggr = aggrMap.indexSimple(offset, semicolonPos, hash); + return parseDataPoint(aggr, data, offset + 1 + semicolonPos); + } + + long currOffset = offset; + while (true) { + // Process line by line, currOffset is the offset of the current line in + // the file, localOffset is the offset of the current line with respect + // to the start of the iteration segment + int localOffset = (int) (currOffset - offset); + + // The key length + long semicolonPos = Long.numberOfTrailingZeros(semicolons) - localOffset; + int hash = data.get(ValueLayout.JAVA_INT_UNALIGNED, currOffset); + if (semicolonPos < Integer.BYTES) { + hash = (byte) hash; + } + + // We inline the searching of the value in the hash map + Aggregator aggr; + hash = PoorManMap.rehash(hash); + int bucketMask = aggrMap.nodes.length - 1; + int bucket = hash & bucketMask; + for (;; bucket = (bucket + 1) & bucketMask) { + PoorManMap.PoorManMapNode node = aggrMap.nodes[bucket]; + if (node == null) { + aggrMap.size++; + if (aggrMap.size * PoorManMap.R_LOAD_FACTOR > aggrMap.nodes.length) { + aggrMap.grow(); + bucketMask = aggrMap.nodes.length - 1; + for (bucket = hash & bucketMask; aggrMap.nodes[bucket] != null; bucket = (bucket + 1) & bucketMask) { + } + } + node = new PoorManMap.PoorManMapNode(data, currOffset, semicolonPos, hash); + aggrMap.nodes[bucket] = node; + aggr = node.aggr; + break; + } + + if (node.hash != hash || node.size != semicolonPos) { + continue; + } + + // The technique here is to align the key in both vectors so that we can do an + // element-wise comparison and check if all characters match + var nodeKey = ByteVector.fromArray(BYTE_SPECIES, node.data, BYTE_SPECIES.length() - localOffset); + var eqMask = line.compare(VectorOperators.EQ, nodeKey).toLong(); + long validMask = (-1L >>> -semicolonPos) << localOffset; + if ((eqMask & validMask) == validMask) { + aggr = node.aggr; + break; + } + } + + long nextOffset = parseDataPoint(aggr, data, currOffset + 1 + semicolonPos); + semicolons &= (semicolons - 1); + if (semicolons == 0) { + return nextOffset; + } + currOffset = nextOffset; + } + } + + // Process all lines that start in [offset, limit) + private static PoorManMap processFile(MemorySegment data, long offset, long limit) { + var aggrMap = new PoorManMap(data); + // Find the start of a new line + if (offset != 0) { + offset--; + for (; offset < limit;) { + if (data.get(ValueLayout.JAVA_BYTE, offset++) == '\n') { + break; + } + } + } + + // If there is no line starting in this segment, just return + if (offset == limit) { + return aggrMap; + } + + // The main loop, optimized for speed + while (offset < limit - Math.max(BYTE_SPECIES.vectorByteSize(), + Long.BYTES + 1 + KEY_MAX_SIZE)) { + offset = iterate(aggrMap, data, offset); + } + + // Now we are at the tail, just be simple + while (offset < limit) { + long semicolonPos = 0; + for (; data.get(ValueLayout.JAVA_BYTE, offset + semicolonPos) != ';'; semicolonPos++) { + } + int hash; + if (semicolonPos >= Integer.BYTES) { + hash = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset); + } + else { + hash = data.get(ValueLayout.JAVA_BYTE, offset); + } + var aggr = aggrMap.indexSimple(offset, semicolonPos, hash); + offset = parseDataPointTail(aggr, data, offset + 1 + semicolonPos); + } + + return aggrMap; + } + + public static void main(String[] args) throws InterruptedException, IOException { + int processorCnt = Runtime.getRuntime().availableProcessors(); + var res = HashMap. newHashMap(processorCnt); + try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ); + var arena = Arena.ofShared()) { + var data = file.map(MapMode.READ_ONLY, 0, file.size(), arena); + long chunkSize = Math.ceilDiv(data.byteSize(), processorCnt); + var threadList = new Thread[processorCnt]; + var resultList = new PoorManMap[processorCnt]; + for (int i = 0; i < processorCnt; i++) { + int index = i; + long offset = i * chunkSize; + long limit = Math.min((i + 1) * chunkSize, data.byteSize()); + var thread = new Thread(() -> { + resultList[index] = processFile(data, offset, limit); + }); + threadList[index] = thread; + thread.start(); + } + for (var thread : threadList) { + thread.join(); + } + + // Collect the results + for (var aggrMap : resultList) { + for (var node : aggrMap.nodes) { + if (node == null) { + continue; + } + byte[] keyData = Arrays.copyOfRange(node.data, BYTE_SPECIES.vectorByteSize(), BYTE_SPECIES.vectorByteSize() + (int) node.size); + String key = new String(keyData, StandardCharsets.UTF_8); + var aggr = node.aggr; + var resAggr = new Aggregator(); + var existingAggr = res.putIfAbsent(key, resAggr); + if (existingAggr != null) { + resAggr = existingAggr; + } + resAggr.min = Math.min(resAggr.min, aggr.min); + resAggr.max = Math.max(resAggr.max, aggr.max); + resAggr.sum += aggr.sum; + resAggr.count += aggr.count; + } + } + } + + Map measurements = new TreeMap<>(); + for (var entry : res.entrySet()) { + String key = entry.getKey(); + var aggr = entry.getValue(); + measurements.put(key, new ResultRow((double) aggr.min / 10, (double) aggr.sum / (aggr.count * 10), (double) aggr.max / 10)); + } + System.out.println(measurements); + } +}