diff --git a/calculate_average_yavuztas.sh b/calculate_average_yavuztas.sh index bfa7b10..bbcd403 100755 --- a/calculate_average_yavuztas.sh +++ b/calculate_average_yavuztas.sh @@ -15,5 +15,11 @@ # limitations under the License. # -JAVA_OPTS="-Xms128m -Xmx128m -XX:MaxGCPauseMillis=1 -XX:-AlwaysPreTouch -XX:+UseSerialGC --enable-preview" -java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yavuztas +if [ -f target/CalculateAverage_yavuztas_image ]; then + echo "Picking up existing native image 'target/CalculateAverage_yavuztas_image', delete the file to select JVM mode." 1>&2 + target/CalculateAverage_yavuztas_image +else + JAVA_OPTS="-XX:MaxGCPauseMillis=1 -XX:-AlwaysPreTouch -XX:+UseSerialGC -XX:+TieredCompilation --enable-preview" + echo "Choosing to run the app in JVM mode as no native image was found, use prepare_yavuztas.sh to generate." 1>&2 + java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yavuztas +fi diff --git a/prepare_yavuztas.sh b/prepare_yavuztas.sh index f83a3ff..f9871af 100755 --- a/prepare_yavuztas.sh +++ b/prepare_yavuztas.sh @@ -16,4 +16,9 @@ # source "$HOME/.sdkman/bin/sdkman-init.sh" -sdk use java 21.0.1-graal 1>&2 +sdk use java 21.0.2-graal 1>&2 + +if [ ! -f target/CalculateAverage_yavuztas_image ]; then + NATIVE_IMAGE_OPTS="--initialize-at-build-time=dev.morling.onebrc.CalculateAverage_yavuztas --gc=epsilon -O3 -march=native -R:MaxHeapSize=128m -H:-GenLoopSafepoints --enable-preview" + native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_yavuztas_image dev.morling.onebrc.CalculateAverage_yavuztas +fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java b/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java index e33fe7e..0e589a4 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java @@ -17,15 +17,16 @@ package dev.morling.onebrc; import sun.misc.Unsafe; -import java.io.IOException; import java.lang.foreign.Arena; import java.lang.reflect.Field; -import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.TreeMap; import java.util.function.Consumer; @@ -35,8 +36,9 @@ public class CalculateAverage_yavuztas { private static final Unsafe UNSAFE = unsafe(); - // Tried all there: MappedByteBuffer, MemorySegment and Unsafe - // Accessing the memory using Unsafe is still the fastest in my experience + // I compared all three: MappedByteBuffer, MemorySegment and Unsafe. + // Accessing the memory using Unsafe is still the fastest in my experience. + // However, I would never use it in production, single programming error will crash your app. private static Unsafe unsafe() { try { final Field f = Unsafe.class.getDeclaredField("theUnsafe"); @@ -48,296 +50,419 @@ public class CalculateAverage_yavuztas { } } + /** + * Extract bytes from a long + */ + private static long partial(long word, int length) { + final long mask = (~0L) << (length << 3); + return word & (~mask); + } + // Only one object, both for measurements and keys, less object creation in hotpots is always faster - static class Record { + private static final class Record { - // keep memory starting address for each segment - // since we use Unsafe, this is enough to align and fetch the data - long segment; - int start; - int length; - int hash; + private final long start; // memory address of the underlying data + private final int length; + private final long word1; + private final long word2; + private final long wordLast; + private final int hash; + private Record next; // linked list to resolve hash collisions - private int min = 1000; // calculations over int is faster than double, we convert to double in the end only once - private int max = -1000; + private int min; // calculations over int is faster than double, we convert to double in the end only once + private int max; private long sum; - private long count; + private int count; - public Record(long segment, int start, int length, int hash) { - this.segment = segment; + public Record(long start, int length, long word1, long word2, long wordLast, int hash, int temp) { this.start = start; this.length = length; + this.word1 = word1; + this.word2 = word2; + this.wordLast = wordLast; this.hash = hash; + this.min = temp; + this.max = temp; + this.sum = temp; + this.count = 1; } @Override public boolean equals(Object o) { final Record record = (Record) o; - return equals(record.segment, record.start, record.length, record.hash); + return equals(record.start, record.word1, record.word2, record.wordLast, record.length); } - /** - * Stateless equals, no Record object needed - */ - public boolean equals(long segment, int start, int length, int hash) { - if (this.length != length || this.hash != hash) + private static boolean notEquals(long address1, long address2, int step) { + return UNSAFE.getLong(address1 + step) != UNSAFE.getLong(address2 + step); + } + + private static boolean equalsComparingLongs(long start1, long start2, int length) { + // first shortcuts + if (length < 24) + return true; + if (length < 32) + return !notEquals(start1, start2, 16); + + int step = 24; // starting from 3rd long + length -= step; + while (length >= 8) { // scan longs + if (notEquals(start1, start2, step)) { + return false; + } + length -= 8; + step += 8; // 8 bytes + } + return true; + } + + private boolean equals(long start, long word1, long word2, long last, int length) { + if (this.word1 != word1) + return false; + if (this.word2 != word2) return false; - int i = 0; // bytes mismatch check - while (i < this.length - && UNSAFE.getByte(this.segment + this.start + i) == UNSAFE.getByte(segment + start + i)) { - i++; - } - return i == this.length; - } - - @Override - public int hashCode() { - return this.hash; + // equals check is done by comparing longs instead of byte by byte check, this is faster + return equalsComparingLongs(this.start, start, length) && this.wordLast == last; } @Override public String toString() { final byte[] bytes = new byte[this.length]; - int i = 0; - while (i < this.length) { - bytes[i] = UNSAFE.getByte(this.segment + this.start + i++); - } - + UNSAFE.copyMemory(null, this.start, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, this.length); return new String(bytes, StandardCharsets.UTF_8); } - public Record collect(int temp) { - this.min = Math.min(this.min, temp); - this.max = Math.max(this.max, temp); + private void collect(int temp) { + if (temp < this.min) + this.min = temp; + if (temp > this.max) + this.max = temp; this.sum += temp; this.count++; - return this; } - public void merge(Record other) { - this.min = Math.min(this.min, other.min); - this.max = Math.max(this.max, other.max); - this.sum += other.sum; - this.count += other.count; + private void merge(Record that) { + if (that.min < this.min) + this.min = that.min; + if (that.max > this.max) + this.max = that.max; + this.sum += that.sum; + this.count += that.count; } - public String measurements() { + private String measurements() { // here is only executed once for each unique key, so StringBuilder creation doesn't harm final StringBuilder sb = new StringBuilder(14); - sb.append(this.min / 10.0); - sb.append("/"); - sb.append(round((this.sum / 10.0) / this.count)); - sb.append("/"); - sb.append(this.max / 10.0); + sb.append(round(this.min)).append("/"); + sb.append(round(1.0 * this.sum / this.count)).append("/"); + sb.append(round(this.max)); return sb.toString(); } } // Inspired by @spullara - customized hashmap on purpose - // The main difference is we hold only one array instead of two - static class RecordMap { + // The main difference is we hold only one array instead of two, fewer objects is faster + private static final class RecordMap { - static final int SIZE = 1 << 15; // 32k - bigger bucket size less collisions - static final int BITMASK = SIZE - 1; - Record[] keys = new Record[SIZE]; + // Bigger bucket size less collisions, but you have to find a sweet spot otherwise it is becoming slower. + // Also works good enough for 10K stations + private static final int SIZE = 1 << 14; // 16kb - enough for 10K + private static final int BITMASK = SIZE - 1; + private final Record[] keys = new Record[SIZE]; - static int hashBucket(int hash) { + // int collision; + + private boolean hasNoRecord(int index) { + return this.keys[index] == null; + } + + private Record getRecord(int index) { + return this.keys[index]; + } + + private static int hashBucket(int hash) { hash = hash ^ (hash >>> 16); // naive bit spreading but surprisingly decreases collision :) return hash & BITMASK; // fast modulo, to find bucket } - void putAndCollect(long segment, int start, int length, int hash, int temp) { - int bucket = hashBucket(hash); - Record existing = this.keys[bucket]; - if (existing == null) { - this.keys[bucket] = new Record(segment, start, length, hash) - .collect(temp); + private void putAndCollect(int hash, int temp, long start, int length, long word1, long word2, long wordLast) { + final int bucket = hashBucket(hash); + if (hasNoRecord(bucket)) { + this.keys[bucket] = new Record(start, length, word1, word2, wordLast, hash, temp); return; } - if (!existing.equals(segment, start, length, hash)) { - // collision, linear probing to find a slot - while ((existing = this.keys[++bucket & BITMASK]) != null && !existing.equals(segment, start, length, hash)) { - // can be stuck here if all the buckets are full :( - // However, since the data set is max 10K (unique) this shouldn't happen - // So, I'm happily leave here branchless :) - } - if (existing == null) { - this.keys[bucket & BITMASK] = new Record(segment, start, length, hash) - .collect(temp); + Record existing = getRecord(bucket); + if (existing.equals(start, word1, word2, wordLast, length)) { + existing.collect(temp); + return; + } + + // collision++; + // find possible slot by scanning the slot linked list + while (existing.next != null) { + if (existing.next.equals(start, word1, word2, wordLast, length)) { + existing.next.collect(temp); return; } - existing.collect(temp); - } - else { - existing.collect(temp); + existing = existing.next; // go on to next + // collision++; } + existing.next = new Record(start, length, word1, word2, wordLast, hash, temp); } - void putOrMerge(Record key) { - int bucket = hashBucket(key.hash); - Record existing = this.keys[bucket]; - if (existing == null) { + private void putOrMerge(Record key) { + final int bucket = hashBucket(key.hash); + if (hasNoRecord(bucket)) { + key.next = null; this.keys[bucket] = key; return; } - if (!existing.equals(key)) { - // collision, linear probing to find a slot - while ((existing = this.keys[++bucket & BITMASK]) != null && !existing.equals(key)) { - // can be stuck here if all the buckets are full :( - // However, since the data set is max 10K (unique keys) this shouldn't happen - // So, I'm happily leave here branchless :) - } - if (existing == null) { - this.keys[bucket & BITMASK] = key; + Record existing = getRecord(bucket); + if (existing.equals(key)) { + existing.merge(key); + return; + } + + // collision++; + // find possible slot by scanning the slot linked list + while (existing.next != null) { + if (existing.next.equals(key)) { + existing.next.merge(key); return; } - existing.merge(key); - } - else { - existing.merge(key); + existing = existing.next; // go on to next + // collision++; } + key.next = null; + existing.next = key; } - void forEach(Consumer consumer) { + private void forEach(Consumer consumer) { int pos = 0; Record key; - while (pos < this.keys.length) { + while (pos < SIZE) { if ((key = this.keys[pos++]) == null) { continue; } + Record next = key.next; consumer.accept(key); + while (next != null) { // also traverse the records in the collision list + final Record tmp = next.next; + consumer.accept(next); + next = tmp; + } } } - void merge(RecordMap other) { + private void merge(RecordMap other) { other.forEach(this::putOrMerge); } } // One actor for one thread, no synchronization - static class RegionActor { + private static final class RegionActor extends Thread { - final FileChannel channel; - final long startPos; - final int size; - final RecordMap map = new RecordMap(); - long segmentAddress; - int position; - Thread runner; // each actor has its own thread + private final long startPos; // start of region memory address + private final int size; - public RegionActor(FileChannel channel, long startPos, int size) { - this.channel = channel; + private final RecordMap map = new RecordMap(); + + public RegionActor(long startPos, int size) { this.startPos = startPos; this.size = size; } - void accumulate() { - this.runner = new Thread(() -> { - try { - // get the segment memory address, this is the only thing we need for Unsafe - this.segmentAddress = this.channel.map(FileChannel.MapMode.READ_ONLY, this.startPos, this.size, Arena.global()).address(); - } - catch (IOException e) { - // no-op - skip intentionally, no handling for the purpose of this challenge - } + private static long getWord(long address) { + return UNSAFE.getLong(address); + } - int start; - int keyHash; - int length; - while (this.position < this.size) { - byte b; - start = this.position; // save line start position - keyHash = UNSAFE.getByte(this.segmentAddress + this.position++); // first byte is guaranteed not to be ';' - length = 1; // min key length - while ((b = UNSAFE.getByte(this.segmentAddress + this.position++)) != ';') { // read until semicolon - keyHash = calculateHash(keyHash, b); // calculate key hash ahead, eleminates one more loop later - length++; + // hasvalue & haszero + // adapted from https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord + private static long hasSemicolon(long word) { + // semicolon pattern + final long hasVal = word ^ 0x3B3B3B3B3B3B3B3BL; // hasvalue + return ((hasVal - 0x0101010101010101L) & ~hasVal & 0x8080808080808080L); // haszero + } + + private static int semicolonPos(long hasVal) { + return Long.numberOfTrailingZeros(hasVal) >>> 3; + } + + private static int decimalPos(long numberWord) { + return Long.numberOfTrailingZeros(~numberWord & 0x10101000); + } + + private static final int MAX_INNER_LOOP_SIZE = 11; + + @Override + public void run() { + long pointer = this.startPos; + final long size = pointer + this.size; + while (pointer < size) { // line start + long hash = 0; // reset hash + long s; // semicolon check word + final int pos; // semicolon position + long word1 = getWord(pointer); + if ((s = hasSemicolon(word1)) != 0) { + pos = semicolonPos(s); + // read temparature + final long numberWord = getWord(pointer + pos + 1); + final int decimalPos = decimalPos(numberWord); + final int temp = convertIntoNumber(decimalPos, numberWord); + + word1 = partial(word1, pos); // last word + this.map.putAndCollect(completeHash(hash, word1), temp, pointer, pos, word1, 0, 0); + + pointer += pos + (decimalPos >>> 3) + 4; + } + else { + long word2 = getWord(pointer + 8); + if ((s = hasSemicolon(word2)) != 0) { + pos = semicolonPos(s); + // read temparature + final int length = pos + 8; + final long numberWord = getWord(pointer + length + 1); + final int decimalPos = decimalPos(numberWord); + final int temp = convertIntoNumber(decimalPos, numberWord); + + word2 = partial(word2, pos); // last word + this.map.putAndCollect(completeHash(hash, word1, word2), temp, pointer, length, word1, word2, 0); + + pointer += length + (decimalPos >>> 3) + 4; // seek to the line end } + else { + long word = 0; + int length = 16; + hash = appendHash(hash, word1, word2); + // Let the compiler know the loop size ahead + // Then it's automatically unrolled + // Max key length is 13 longs, 2 we've read before, 11 left + for (int i = 0; i < MAX_INNER_LOOP_SIZE; i++) { + if ((s = hasSemicolon((word = getWord(pointer + length)))) != 0) { + break; + } + hash = appendHash(hash, word); + length += 8; + } - final int temp = readTemperature(); - this.map.putAndCollect(this.segmentAddress, start, length, keyHash, temp); + pos = semicolonPos(s); + length += pos; + // read temparature + final long numberWord = getWord(pointer + length + 1); + final int decimalPos = decimalPos(numberWord); + final int temp = convertIntoNumber(decimalPos, numberWord); - this.position++; // skip linebreak - } - }); - this.runner.start(); - } + word = partial(word, pos); // last word + this.map.putAndCollect(completeHash(hash, word), temp, pointer, length, word1, word2, word); - static int calculateHash(int hash, int b) { - return 31 * hash + b; - } - - // 1. Inspired by @yemreinci - Reading temparature value without Double.parse - // 2. Inspired by @obourgain - Fetching first 4 bytes ahead, then masking - int readTemperature() { - int temp = 0; - // read 4 bytes ahead - final int first4 = UNSAFE.getInt(this.segmentAddress + this.position); - this.position += 3; - - final byte b1 = (byte) first4; // first byte - final byte b2 = (byte) ((first4 >> 8) & 0xFF); // second byte - final byte b3 = (byte) ((first4 >> 16) & 0xFF); // third byte - if (b1 == '-') { - if (b3 == '.') { - temp -= 10 * (b2 - '0') + (byte) ((first4 >> 24) & 0xFF) - '0'; // fourth byte - this.position++; - } - else { - this.position++; // skip dot - temp -= 100 * (b2 - '0') + 10 * (b3 - '0') + UNSAFE.getByte(this.segmentAddress + this.position++) - '0'; // fifth byte - } - } - else { - if (b2 == '.') { - temp = 10 * (b1 - '0') + b3 - '0'; - } - else { - temp = 100 * (b1 - '0') + 10 * (b2 - '0') + (byte) ((first4 >> 24) & 0xFF) - '0'; // fourth byte - this.position++; + pointer += length + (decimalPos >>> 3) + 4; // seek to the line end + } } } + } - return temp; + // Hashes are calculated by a Mersenne Prime (1 << 7) -1 + // This is faster than multiplication in some machines + private static long appendHash(long hash, long word) { + return (hash << 7) - hash + word; + } + + private static long appendHash(long hash, long word1, long word2) { + hash = (hash << 7) - hash + word1; + return (hash << 7) - hash + word2; + } + + private static int completeHash(long hash, long partial) { + hash = (hash << 7) - hash + partial; + return (int) (hash ^ (hash >>> 25)); + } + + private static int completeHash(long hash, long word1, long word2) { + hash = (hash << 7) - hash + word1; + hash = (hash << 7) - hash + word2; + return (int) hash ^ (int) (hash >>> 25); + } + + // Credits to @merrykitty. Magical solution to parse temparature values branchless! + // Taken as without modification, comments belong to @merrykitty + private static int convertIntoNumber(int decimalSepPos, long numberWord) { + final int shift = 28 - decimalSepPos; + // signed is -1 if negative, 0 otherwise + final long signed = (~numberWord << 59) >> 63; + final long designMask = ~(signed & 0xFF); + // Align the number to a specific position and transform the ascii code + // to actual digit value in each byte + final long digits = ((numberWord & designMask) << shift) & 0x0F000F0F00L; + // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit) + // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) = + // 0x000000UU00TTHH00 + + // 0x00UU00TTHH000000 * 10 + + // 0xUU00TTHH00000000 * 100 + // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400 + // This results in our value lies in the bit 32 to 41 of this product + // That was close :) + final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; + final long value = (absValue ^ signed) - signed; + return (int) value; } /** * blocks until the map is fully collected */ - RecordMap get() throws InterruptedException { - this.runner.join(); + private RecordMap get() throws InterruptedException { + join(); return this.map; } } private static double round(double value) { - return Math.round(value * 10.0) / 10.0; + return Math.round(value) / 10.0; } /** * Scans the given buffer to the left */ - private static long findClosestLineEnd(long start, int size, FileChannel channel) throws IOException { - final long position = start + size; - final long left = Math.max(position - 101, 0); - final ByteBuffer buffer = ByteBuffer.allocate(101); // enough size to find at least one '\n' - if (channel.read(buffer.clear(), left) != -1) { - int bufferPos = buffer.position() - 1; - while (buffer.get(bufferPos) != '\n') { - bufferPos--; - size--; - } + private static long findClosestLineEnd(long start, int size) { + long position = start + size; + while (UNSAFE.getByte(--position) != '\n') { + // read until a linebreak + size--; } return size; } - public static void main(String[] args) throws IOException, InterruptedException { + private static boolean isWorkerProcess(String[] args) { + return Arrays.asList(args).contains("--worker"); + } - var concurrency = Runtime.getRuntime().availableProcessors(); + private static void runAsWorker() throws Exception { + final ProcessHandle.Info info = ProcessHandle.current().info(); + final List commands = new ArrayList<>(); + info.command().ifPresent(commands::add); + info.arguments().ifPresent(args -> commands.addAll(Arrays.asList(args))); + commands.add("--worker"); + + new ProcessBuilder() + .command(commands) + .start() + .getInputStream() + .transferTo(System.out); + } + + public static void main(String[] args) throws Exception { + + // Dased on @thomaswue's idea, to cut unmapping delay. + // Strangely, unmapping delay doesn't occur on macOS/M1 however in Linux/AMD it's substantial - ~200ms + if (!isWorkerProcess(args)) { + runAsWorker(); + return; + } + + var concurrency = 2 * Runtime.getRuntime().availableProcessors(); final long fileSize = Files.size(FILE); long regionSize = fileSize / concurrency; @@ -353,30 +478,36 @@ public class CalculateAverage_yavuztas { long startPos = 0; final FileChannel channel = (FileChannel) Files.newByteChannel(FILE, StandardOpenOption.READ); + // get the memory address, this is the only thing we need for Unsafe + final long memoryAddress = channel.map(FileChannel.MapMode.READ_ONLY, startPos, fileSize, Arena.global()).address(); + final RegionActor[] actors = new RegionActor[concurrency]; for (int i = 0; i < concurrency; i++) { // calculate boundaries long maxSize = (startPos + regionSize > fileSize) ? fileSize - startPos : regionSize; // shift position to back until we find a linebreak - maxSize = findClosestLineEnd(startPos, (int) maxSize, channel); + maxSize = findClosestLineEnd(memoryAddress + startPos, (int) maxSize); - final RegionActor region = (actors[i] = new RegionActor(channel, startPos, (int) maxSize)); - region.accumulate(); + final RegionActor region = (actors[i] = new RegionActor(memoryAddress + startPos, (int) maxSize)); + region.start(); // start processing startPos += maxSize; } - final RecordMap output = new RecordMap(); // output to merge all regions + final RecordMap output = new RecordMap(); // output to merge all records for (RegionActor actor : actors) { final RecordMap partial = actor.get(); // blocks until get the result output.merge(partial); + // System.out.println("collisions: " + partial.collision); } // sort and print the result final TreeMap sorted = new TreeMap<>(); - output.forEach(key -> sorted.put(key.toString(), key.measurements())); + output.forEach(key -> { + sorted.put(key.toString(), key.measurements()); + }); System.out.println(sorted); - + System.out.close(); // closing the stream will trigger the main process to pick up the output early } }