From 8c248714061535a819181ca137876fa1435a7e7d Mon Sep 17 00:00:00 2001
From: Roy van Rijn
Date: Thu, 11 Jan 2024 11:12:05 +0100
Subject: [PATCH] Fixing the off-by-one error and updating to native, redone
 layout of code. (#307)

---
 calculate_average_royvanrijn.sh                 |  15 +-
 prepare_royvanrijn.sh                           |   9 +
 .../onebrc/CalculateAverage_royvanrijn.java     | 384 ++++++++----------
 3 files changed, 196 insertions(+), 212 deletions(-)

diff --git a/calculate_average_royvanrijn.sh b/calculate_average_royvanrijn.sh
index 2a24bfa..6931bf0 100755
--- a/calculate_average_royvanrijn.sh
+++ b/calculate_average_royvanrijn.sh
@@ -15,5 +15,16 @@
 #  limitations under the License.
 #
 
-JAVA_OPTS="--enable-preview"
-java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_royvanrijn
+if [ -f target/CalculateAverage_royvanrijn_image ]; then
+    echo "Picking up existing native image 'target/CalculateAverage_royvanrijn_image', delete the file to select JVM mode." 1>&2
+    target/CalculateAverage_royvanrijn_image
+else
+    JAVA_OPTS="--enable-preview -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA"
+    if [[ ! "$(uname -s)" = "Darwin" ]]; then
+        # On OS/X (my machine), this errors:
+        JAVA_OPTS="$JAVA_OPTS -XX:+UseTransparentHugePages"
+    fi
+    echo "Choosing to run the app in JVM mode as no native image was found, use additional_build_step_royvanrijn.sh to generate." 1>&2
+    java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_royvanrijn
+fi
+
diff --git a/prepare_royvanrijn.sh b/prepare_royvanrijn.sh
index f83a3ff..2088b7b 100755
--- a/prepare_royvanrijn.sh
+++ b/prepare_royvanrijn.sh
@@ -17,3 +17,12 @@
 
 source "$HOME/.sdkman/bin/sdkman-init.sh"
 sdk use java 21.0.1-graal 1>&2
+
+# ./mvnw clean verify removes target/ and will re-trigger native image creation.
+if [ ! -f target/CalculateAverage_royvanrijn_image ]; then
+
+    JAVA_OPTS="--enable-preview -dsa"
+    NATIVE_IMAGE_OPTS="--gc=epsilon -Ob -O3 -march=native --strict-image-heap $JAVA_OPTS"
+
+    native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_royvanrijn_image dev.morling.onebrc.CalculateAverage_royvanrijn
+fi
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
index aa22bef..4cf9925 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
@@ -18,52 +18,52 @@ package dev.morling.onebrc;
 
 import java.io.IOException;
 import java.lang.foreign.Arena;
 import java.lang.reflect.Field;
-import java.nio.ByteOrder;
 import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
 import java.nio.file.Path;
 import java.nio.file.StandardOpenOption;
-import java.util.Arrays;
-import java.util.Objects;
-import java.util.TreeMap;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
-import java.util.stream.Stream;
 
 import sun.misc.Unsafe;
 
 /**
  * Changelog:
  *
- * Initial submission:             62000 ms
- * Chunked reader:                 16000 ms
- * Optimized parser:               13000 ms
- * Branchless methods:             11000 ms
- * Adding memory mapped files:      6500 ms (based on bjhara's submission)
- * Skipping string creation:        4700 ms
- * Custom hashmap...                4200 ms
- * Added SWAR token checks:         3900 ms
- * Skipped String creation:         3500 ms (idea from kgonia)
- * Improved String skip:            3250 ms
- * Segmenting files:                3150 ms (based on spullara's code)
- * Not using SWAR for EOL:          2850 ms
- * Inlining hash calculation:       2450 ms
- * Replacing branchless code:       2200 ms (sometimes we need to kill the things we love)
- * Added unsafe memory access:      1900 ms (keeping the long[] small and local)
- *
- * Best performing JVM on MacBook M2 Pro: 21.0.1-graal
- * `sdk use java 21.0.1-graal`
+ * Initial submission:             62000 ms
+ * Chunked reader:                 16000 ms
+ * Optimized parser:               13000 ms
+ * Branchless methods:             11000 ms
+ * Adding memory mapped files:      6500 ms (based on bjhara's submission)
+ * Skipping string creation:        4700 ms
+ * Custom hashmap...                4200 ms
+ * Added SWAR token checks:         3900 ms
+ * Skipped String creation:         3500 ms (idea from kgonia)
+ * Improved String skip:            3250 ms
+ * Segmenting files:                3150 ms (based on spullara's code)
+ * Not using SWAR for EOL:          2850 ms
+ * Inlining hash calculation:       2450 ms
+ * Replacing branchless code:       2200 ms (sometimes we need to kill the things we love)
+ * Added unsafe memory access:      1900 ms (keeping the long[] small and local)
+ * Fixed bug, UNSAFE bytes String:  1850 ms
+ * Separate hash from entries:      1550 ms
+ * Various tweaks for Linux/cache:  1550 ms (should/could make a difference on target machine)
+ * Improved layout/predictability:  1450 ms (on par with Thomas Wuerthinger)
  *
+ * Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai for ideas.
  */
 public class CalculateAverage_royvanrijn {
 
     private static final String FILE = "./measurements.txt";
 
     private static final Unsafe UNSAFE = initUnsafe();
 
-    private static final boolean isBigEndian = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN);
-
     private static Unsafe initUnsafe() {
         try {
-            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+            final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
             theUnsafe.setAccessible(true);
             return (Unsafe) theUnsafe.get(Unsafe.class);
         }
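Everything below leans on the pattern the changelog calls "added unsafe memory access": the input file is mapped once and all reads then go through raw addresses. As a self-contained illustration of that pattern (not part of the patch; the class and variable names here are made up):

    import java.lang.foreign.Arena;
    import java.lang.reflect.Field;
    import java.nio.channels.FileChannel;
    import java.nio.file.Path;
    import java.nio.file.StandardOpenOption;

    import sun.misc.Unsafe;

    public class MappedReadDemo {
        public static void main(String[] args) throws Exception {
            // Same Unsafe bootstrap as initUnsafe() above.
            final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            theUnsafe.setAccessible(true);
            final Unsafe unsafe = (Unsafe) theUnsafe.get(Unsafe.class);

            try (var fileChannel = FileChannel.open(Path.of("measurements.txt"), StandardOpenOption.READ)) {
                // Map the whole file and keep only the raw address around:
                final long address = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size(), Arena.global()).address();
                // Read 8 bytes in one go, with no bounds checks; the SWAR scan
                // further down in the patch does exactly this on every iteration.
                System.out.printf("first word: 0x%016x%n", unsafe.getLong(address));
            }
        }
    }

Run it with --enable-preview on the JDK the scripts select, since the Arena-based FileChannel.map overload is a preview API on Java 21.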
@@ -73,32 +73,42 @@ public class CalculateAverage_royvanrijn {
     }
 
     public static void main(String[] args) throws Exception {
-        new CalculateAverage_royvanrijn().run();
-    }
-
-    public void run() throws Exception {
         // Calculate input segments.
-        int numberOfChunks = Runtime.getRuntime().availableProcessors();
-        long[] chunks = getSegments(numberOfChunks);
+        final int numberOfChunks = Runtime.getRuntime().availableProcessors();
+        final long[] chunks = getSegments(numberOfChunks);
 
-        // Parallel processing of segments.
-        TreeMap<String, Measurement> results = IntStream.range(0, chunks.length - 1)
-                .mapToObj(chunkIndex -> process(chunks[chunkIndex], chunks[chunkIndex + 1])).parallel()
-                .flatMap(MeasurementRepository::get)
-                .collect(Collectors.toMap(e -> e.city, MeasurementRepository.Entry::measurement, Measurement::updateWith, TreeMap::new));
+        final List<Entry[]> repositories = IntStream.range(0, chunks.length - 1)
+                .mapToObj(chunkIndex -> processMemoryArea(chunks[chunkIndex], chunks[chunkIndex + 1]))
+                .parallel()
+                .toList();
+
+        // Sometimes simple is better:
+        final HashMap<String, Entry> measurements = HashMap.newHashMap(1 << 10);
+        for (Entry[] entries : repositories) {
+            for (Entry entry : entries) {
+                if (entry != null)
+                    measurements.merge(entry.city, entry, Entry::mergeWith);
+            }
+        }
+
+        System.out.print("{" +
+                measurements.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Object::toString).collect(Collectors.joining(", ")));
+        System.out.println("}");
 
-        System.out.println(results);
     }
 
-    private static long[] getSegments(int numberOfChunks) throws IOException {
+    /**
+     * Simpler way to get the segments and launch parallel processing, by thomaswue
+     */
+    private static long[] getSegments(final int numberOfChunks) throws IOException {
         try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
-            long fileSize = fileChannel.size();
-            long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks;
-            long[] chunks = new long[numberOfChunks + 1];
-            long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
+            final long fileSize = fileChannel.size();
+            final long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks;
+            final long[] chunks = new long[numberOfChunks + 1];
+            final long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
             chunks[0] = mappedAddress;
-            long endAddress = mappedAddress + fileSize;
+            final long endAddress = mappedAddress + fileSize;
             for (int i = 1; i < numberOfChunks; ++i) {
                 long chunkAddress = mappedAddress + i * segmentSize;
                 // Align to first row start.
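The body of the alignment loop falls outside this hunk's context lines. Given how the segments are consumed later, "align to first row start" has to mean: push every boundary forward to just past the next '\n', so that no row is split across two workers. A sketch of the elided loop body (an assumed reconstruction, not the literal file contents):

    for (int i = 1; i < numberOfChunks; ++i) {
        long chunkAddress = mappedAddress + i * segmentSize;
        // Align to first row start: advance past the next newline so a
        // segment boundary never cuts a "station;temperature" row in half.
        while (chunkAddress < endAddress && UNSAFE.getByte(chunkAddress++) != '\n') {
            // skip forward one byte at a time
        }
        chunks[i] = Math.min(chunkAddress, endAddress);
    }
    chunks[numberOfChunks] = endAddress; // the last segment ends at EOF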
@@ -112,108 +122,36 @@ public class CalculateAverage_royvanrijn {
         }
     }
 
-    private MeasurementRepository process(long fromAddress, long toAddress) {
+    private static final int TABLE_SIZE = 1 << 18; // large enough for the contest.
+    private static final int TABLE_MASK = (TABLE_SIZE - 1);
 
-        MeasurementRepository repository = new MeasurementRepository();
-        long ptr = fromAddress;
-        long[] dataBuffer = new long[16];
-        while ((ptr = processEntity(dataBuffer, ptr, toAddress, repository)) < toAddress)
-            ;
+    static final class Entry {
+        private final long[] data;
+        private final String city;
+        private int min, max, count;
+        private long sum;
 
-        return repository;
-    }
-
-    private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');
-
-    /**
-     * Already looping the longs here, lets shoehorn in making a hash
-     */
-    private long processEntity(final long[] data, final long start, final long limit, final MeasurementRepository measurementRepository) {
-        int hash = 1;
-        long i;
-        int dataPtr = 0;
-        for (i = start; i <= limit - 8; i += 8) {
-            long word = UNSAFE.getLong(i);
-            if (isBigEndian) {
-                word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this
-            }
-            final long match = word ^ SEPARATOR_PATTERN;
-            long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L;
-
-            if (mask != 0) {
-
-                final long partialWord = word & ((mask >> 7) - 1);
-                hash = longHashStep(hash, partialWord);
-                data[dataPtr] = partialWord;
-
-                final int index = Long.numberOfTrailingZeros(mask) >> 3;
-                return process(start, i + index, hash, data, measurementRepository);
-            }
-            data[dataPtr++] = word;
-            hash = longHashStep(hash, word);
-        }
-        // Handle remaining bytes near the limit of the buffer:
-        long partialWord = 0;
-        int len = 0;
-        for (; i < limit; i++) {
-            byte read;
-            if ((read = UNSAFE.getByte(i)) == ';') {
-                hash = longHashStep(hash, partialWord);
-                data[dataPtr] = partialWord;
-                return process(start, i, hash, data, measurementRepository);
-            }
-            partialWord = partialWord | ((long) read << (len << 3));
-            len++;
-        }
-        return limit;
-    }
-
-    private static final long DOT_BITS = 0x10101000;
-    private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
-
-    private long process(final long startAddress, final long delimiterAddress, final int hash, final long[] data, final MeasurementRepository measurementRepository) {
-
-        long word = UNSAFE.getLong(delimiterAddress + 1);
-        if (isBigEndian) {
-            word = Long.reverseBytes(word);
-        }
-        final long invWord = ~word;
-        final int decimalSepPos = Long.numberOfTrailingZeros(invWord & DOT_BITS);
-        final long signed = (invWord << 59) >> 63;
-        final long designMask = ~(signed & 0xFF);
-        final long digits = ((word & designMask) << (28 - decimalSepPos)) & 0x0F000F0F00L;
-        final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
-        final int measurement = (int) ((absValue ^ signed) - signed);
-
-        // Store:
-        measurementRepository.update(startAddress, data, (int) (delimiterAddress - startAddress), hash, measurement);
-
-        return delimiterAddress + (decimalSepPos >> 3) + 4; // Determine next start:
-        // return nextAddress;
-    }
-
-    static final class Measurement {
-        int min, max, count;
-        long sum;
-
-        public Measurement() {
-            this.min = 1000;
-            this.max = -1000;
+        Entry(final long[] data, String city, int temp) {
+            this.data = data;
+            this.city = city;
+            this.min = temp;
+            this.max = temp;
+            this.sum = temp;
+            this.count = 1;
         }
 
-        public Measurement updateWith(int measurement) {
-            min = min(min, measurement);
-            max = max(max, measurement);
+        public void updateWith(int measurement) {
+            min = Math.min(min, measurement);
+            max = Math.max(max, measurement);
             sum += measurement;
             count++;
-            return this;
         }
 
-        public Measurement updateWith(Measurement measurement) {
-            min = min(min, measurement.min);
-            max = max(max, measurement.max);
-            sum += measurement.sum;
-            count += measurement.count;
+        public Entry mergeWith(Entry entry) {
+            min = Math.min(min, entry.min);
+            max = Math.max(max, entry.max);
+            sum += entry.sum;
+            count += entry.count;
             return this;
         }
@@ -221,101 +159,127 @@ public class CalculateAverage_royvanrijn {
             return round(min) + "/" + round((1.0 * sum) / count) + "/" + round(max);
         }
 
-        private double round(double value) {
+        private static double round(double value) {
             return Math.round(value) / 10.0;
         }
     }
 
-    // branchless max (unprecise for large numbers, but good enough)
-    static int max(final int a, final int b) {
-        final int diff = a - b;
-        final int dsgn = diff >> 31;
-        return a - (diff & dsgn);
-    }
+    private static Entry createNewEntry(final long[] buffer, final long startAddress, final int lengthLongs, final int lengthBytes, final int temp) {
+
+        // --- This is a brand new entry: insert it into the hashtable and do the extra, slower calculations here, only once.
+        final byte[] bytes = new byte[lengthBytes];
+        UNSAFE.copyMemory(null, startAddress, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, lengthBytes);
+        final String city = new String(bytes, StandardCharsets.UTF_8);
+
+        final long[] bufferCopy = new long[lengthLongs];
+        System.arraycopy(buffer, 0, bufferCopy, 0, lengthLongs);
+
+        // Add the entry:
+        return new Entry(bufferCopy, city, temp);
+    }
 
-    // branchless min (unprecise for large numbers, but good enough)
-    static int min(final int a, final int b) {
-        final int diff = a - b;
-        final int dsgn = diff >> 31;
-        return b + (diff & dsgn);
-    }
+    private static Entry[] processMemoryArea(final long fromAddress, final long toAddress) {
+
+        Entry[] table = new Entry[TABLE_SIZE];
+
+        long ptr = fromAddress;
+        long[] buffer = new long[14];
+
+        while (ptr < toAddress) {
+
+            int bufferPtr = 0;
+            long startAddress = ptr;
+            long hash = 1;
+
+            long word = UNSAFE.getLong(ptr);
+            long mask = getDelimiterMask(word);
+
+            while (mask == 0) {
+                buffer[bufferPtr++] = word;
+                hash ^= word;
+                ptr += 8;
+
+                word = UNSAFE.getLong(ptr);
+                mask = getDelimiterMask(word);
+            }
+
+            // Found delimiter:
+            final long delimiterAddress = ptr + (Long.numberOfTrailingZeros(mask) >> 3);
+            final long numberBits = UNSAFE.getLong(delimiterAddress + 1);
+
+            // Finish the masks and hash:
+            final long partialWord = word & ((mask >> 7) - 1);
+            buffer[bufferPtr++] = partialWord;
+            hash ^= partialWord;
+
+            final long invNumberBits = ~numberBits;
+            final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & DOT_BITS);
+
+            // Update the counter asap, lets the CPU predict:
+            ptr = delimiterAddress + (decimalSepPos >> 3) + 4;
+
+            int intHash = (int) (hash ^ (hash >>> 31)); // offset for extra entropy
+
+            // Awesome idea of merykitty:
+            final int temp = extractTemp(numberBits, invNumberBits, decimalSepPos);
+
+            int index = intHash & TABLE_MASK;
+
+            // Find or insert the entry:
+            while (true) {
+                Entry tableEntry = table[index];
+                if (tableEntry == null) {
+                    final int length = (int) (delimiterAddress - startAddress);
+                    table[index] = createNewEntry(buffer, startAddress, bufferPtr, length, temp);
+                    break;
+                }
+                else if (bufferPtr == tableEntry.data.length) {
+                    if (!arrayEquals(buffer, tableEntry.data, bufferPtr)) {
+                        index = (index + 1) & TABLE_MASK;
+                        continue;
+                    }
+                    // No differences in array
+                    tableEntry.updateWith(temp);
+                    break;
+                }
+                // Move to the next index
+                index = (index + 1) & TABLE_MASK;
+            }
+        }
+        return table;
+    }
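The find-or-insert loop above is plain open addressing: hash to a slot, and on a collision step linearly to the next one, with the power-of-two TABLE_SIZE making the wrap-around a single AND against TABLE_MASK. The same probe structure in a minimal, self-contained form (simplified to String keys; this is not code from the patch):

    public class ProbeDemo {
        static final int TABLE_SIZE = 1 << 4; // power of two, like the patch's 1 << 18
        static final int TABLE_MASK = TABLE_SIZE - 1;
        static final String[] TABLE = new String[TABLE_SIZE];

        // Returns the slot where key lives, claiming an empty slot if absent.
        static int findOrInsert(String key) {
            int index = key.hashCode() & TABLE_MASK;
            while (true) {
                if (TABLE[index] == null) {
                    TABLE[index] = key; // empty slot: insert here
                    return index;
                }
                if (TABLE[index].equals(key)) {
                    return index; // already present
                }
                index = (index + 1) & TABLE_MASK; // collision: linear step, cheap wrap
            }
        }

        public static void main(String[] args) {
            System.out.println(findOrInsert("Amsterdam") == findOrInsert("Amsterdam")); // true
        }
    }

Note that the patch's key comparison is cheaper than equals: entries whose long[] length differs are skipped immediately, and only equal-length keys go through arrayEquals.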
 
-    private static int longHashStep(final int hash, final long word) {
-        return 31 * hash + (int) (word ^ (word >>> 32));
+    private static int extractTemp(final long numberBits, final long invNumberBits, final int decimalSepPos) {
+        final long signed = (invNumberBits << 59) >> 63;
+        final long minusFilter = ~(signed & 0xFF);
+        final long digits = ((numberBits & minusFilter) << (28 - decimalSepPos)) & 0x0F000F0F00L;
+        final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; // filter just the result
+        final int temp = (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick
+        return temp;
     }
 
+    private static long getDelimiterMask(final long word) {
+        long match = word ^ SEPARATOR_PATTERN;
+        return (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
+    }
+
+    private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');
+    private static final long DOT_BITS = 0x10101000;
+    private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
+
+    private static long compilePattern(final byte value) {
+        return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) |
+                ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value;
+    }
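getDelimiterMask is the classic SWAR zero-byte test: XOR-ing a word against SEPARATOR_PATTERN turns any ';' byte into 0x00, and the (x - 0x0101...) & ~x & 0x8080... step leaves bit 7 set in exactly those zero bytes. A worked example (illustrative, not part of the patch):

    public class SwarDemo {
        public static void main(String[] args) {
            long word = 0x332E32313B736D41L; // "Ams;12.3" read as a little-endian long
            long match = word ^ 0x3B3B3B3B3B3B3B3BL; // SEPARATOR_PATTERN: ';' in all 8 bytes
            long mask = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
            // mask now carries 0x80 only in the byte that held ';':
            System.out.println(Long.numberOfTrailingZeros(mask) >> 3); // prints 3
        }
    }

The final shift by 3 (divide by 8) converts a bit position into a byte offset, which is how processMemoryArea turns the mask into delimiterAddress.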
 
-    /**
-     * A normal Java HashMap does all these safety things like boundary checks... we don't need that, we need speeeed.
-     *
-     * So I've written an extremely simple linear probing hashmap that should work well enough.
-     */
-    class MeasurementRepository {
-        private int tableSize = 1 << 20; // large enough for the contest.
-        private int tableMask = (tableSize - 1);
-
-        private MeasurementRepository.Entry[] table = new MeasurementRepository.Entry[tableSize];
-
-        record Entry(long address, long[] data, int length, int hash, String city, Measurement measurement) {
-
-            @Override
-            public String toString() {
-                return city + "=" + measurement;
-            }
-        }
-
-        public void update(long address, long[] data, int length, int hash, int temperature) {
-
-            int dataLength = length >> 3;
-            int index = hash & tableMask;
-            MeasurementRepository.Entry tableEntry;
-            while ((tableEntry = table[index]) != null
-                    && (tableEntry.hash != hash || tableEntry.length != length || !arrayEquals(tableEntry.data, data, dataLength))) { // search for the right spot
-                index = (index + 1) & tableMask;
-            }
-
-            if (tableEntry != null) {
-                tableEntry.measurement.updateWith(temperature);
-                return;
-            }
-
-            // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here.
-            Measurement measurement = new Measurement();
-
-            byte[] bytes = new byte[length];
-            for (int i = 0; i < length; i++) {
-                bytes[i] = UNSAFE.getByte(address + i);
-            }
-            String city = new String(bytes);
-
-            long[] dataCopy = new long[dataLength];
-            System.arraycopy(data, 0, dataCopy, 0, dataLength);
-
-            // And add entry:
-            MeasurementRepository.Entry toAdd = new MeasurementRepository.Entry(address, dataCopy, length, hash, city, measurement);
-            table[index] = toAdd;
-
-            toAdd.measurement.updateWith(temperature);
-        }
-
-        public Stream<Entry> get() {
-            return Arrays.stream(table).filter(Objects::nonNull);
-        }
-    }
-
     /**
      * For case multiple hashes are equal (however unlikely) check the actual key (using longs)
      */
-    private boolean arrayEquals(final long[] a, final long[] b, final int length) {
+    static boolean arrayEquals(final long[] a, final long[] b, final int length) {
         for (int i = 0; i < length; i++) {
             if (a[i] != b[i])
                 return false;
         }
         return true;
     }
-
 }
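extractTemp is merykitty's branchless parser credited in the comments: DOT_BITS locates the '.' by testing bit 4 of each candidate byte (0 for '.' and '-', 1 for ASCII digits), the shift left-aligns the up-to-three digits, and a single multiply by MAGIC_MULTIPLIER sums hundreds, tens and the tenth digit, pre-scaled into tenths of a degree. A worked example for the row suffix "12.3\n" (illustrative, not part of the patch):

    public class TempParseDemo {
        private static final long DOT_BITS = 0x10101000;
        private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);

        public static void main(String[] args) {
            // "12.3\n" as read by UNSAFE.getLong(delimiterAddress + 1), little-endian:
            final long numberBits = 0x0A332E3231L;
            final long invNumberBits = ~numberBits;

            final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & DOT_BITS); // 20: '.' sits in byte 2
            final long signed = (invNumberBits << 59) >> 63; // 0 here; -1 for a leading '-'
            final long minusFilter = ~(signed & 0xFF); // strips the '-' byte when negative
            final long digits = ((numberBits & minusFilter) << (28 - decimalSepPos)) & 0x0F000F0F00L;
            final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
            final int temp = (int) ((absValue + signed) ^ signed);

            System.out.println(temp); // prints 123, i.e. 12.3 stored in tenths
        }
    }

Keeping temperatures as integer tenths is what lets Entry accumulate into plain int/long fields; the division by ten happens only once, in round(), at output time.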