diff --git a/calculate_average_mtopolnik.sh b/calculate_average_mtopolnik.sh index 24b5a1c..acd1024 100755 --- a/calculate_average_mtopolnik.sh +++ b/calculate_average_mtopolnik.sh @@ -15,5 +15,11 @@ # limitations under the License. # -java --enable-preview \ - --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_mtopolnik +if [ -f target/CalculateAverage_mtopolnik_image ]; then + echo "Using native image 'target/CalculateAverage_mtopolnik_image'" 1>&2 + target/CalculateAverage_mtopolnik_image +else + JAVA_OPTS="--enable-preview" + echo "Native image not found, using JVM mode." 1>&2 + java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_mtopolnik +fi diff --git a/prepare_mtopolnik.sh b/prepare_mtopolnik.sh index f83a3ff..d84f20d 100755 --- a/prepare_mtopolnik.sh +++ b/prepare_mtopolnik.sh @@ -16,4 +16,9 @@ # source "$HOME/.sdkman/bin/sdkman-init.sh" -sdk use java 21.0.1-graal 1>&2 +sdk use java 21.0.2-graal 1>&2 + +if [ ! -f target/CalculateAverage_mtopolnik_image ]; then + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -H:+UnlockExperimentalVMOptions -H:-GenLoopSafepoints -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_mtopolnik" + native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_mtopolnik_image dev.morling.onebrc.CalculateAverage_mtopolnik +fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java index 51ea415..61294a4 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java @@ -29,18 +29,15 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; +import static java.lang.ProcessBuilder.Redirect.PIPE; +import static java.util.Arrays.asList; + public class CalculateAverage_mtopolnik { private static final Unsafe UNSAFE = unsafe(); private static final int MAX_NAME_LEN = 100; private static final int STATS_TABLE_SIZE = 1 << 16; private static final int TABLE_INDEX_MASK = STATS_TABLE_SIZE - 1; private static final String MEASUREMENTS_TXT = "measurements.txt"; - private static final byte SEMICOLON = ';'; - private static final long BROADCAST_SEMICOLON = broadcastByte(SEMICOLON); - - // These two are just informative, I let the IDE calculate them for me - private static final long NATIVE_MEM_PER_THREAD = StatsAccessor.SIZEOF * STATS_TABLE_SIZE; - private static final long NATIVE_MEM_ON_8_THREADS = 8 * NATIVE_MEM_PER_THREAD; private static Unsafe unsafe() { try { @@ -53,31 +50,23 @@ public class CalculateAverage_mtopolnik { } } - static class StationStats implements Comparable { - String name; - long sum; - int count; - int min; - int max; - - @Override - public String toString() { - return String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum / count) / 10.0, max / 10.0); - } - - @Override - public boolean equals(Object that) { - return that.getClass() == StationStats.class && ((StationStats) that).name.equals(this.name); - } - - @Override - public int compareTo(StationStats that) { - return name.compareTo(that.name); - } - } - public static void main(String[] args) throws Exception { - calculate(); + if (args.length >= 1 && args[0].equals("--worker")) { + calculate(); + System.out.close(); + return; + } + var curProcInfo = ProcessHandle.current().info(); + var cmdLine = new ArrayList(); + cmdLine.add(curProcInfo.command().get()); + cmdLine.addAll(asList(curProcInfo.arguments().get())); + cmdLine.add("--worker"); + var process = new ProcessBuilder() + .command(cmdLine) + .inheritIO().redirectOutput(PIPE) + .start() + .getInputStream().transferTo(System.out); + } static void calculate() throws Exception { @@ -113,7 +102,6 @@ public class CalculateAverage_mtopolnik { } private static class ChunkProcessor implements Runnable { - private static final long NAMEBUF_SIZE = 2 * Long.BYTES; private static final int CACHELINE_SIZE = 64; private final long inputBase; @@ -122,8 +110,6 @@ public class CalculateAverage_mtopolnik { private final int myIndex; private StatsAccessor stats; - private long nameBufBase; - private long cursor; ChunkProcessor(long chunkStart, long chunkLimit, StationStats[][] results, int myIndex) { this.inputBase = chunkStart; @@ -138,16 +124,12 @@ public class CalculateAverage_mtopolnik { long totalAllocated = 0; String threadName = Thread.currentThread().getName(); long statsByteSize = STATS_TABLE_SIZE * StatsAccessor.SIZEOF; - var diagnosticString = String.format("Thread %s needs %,d bytes, managed to allocate before OOM: ", - threadName, statsByteSize + NAMEBUF_SIZE); + var diagnosticString = String.format("Thread %s needs %,d bytes", threadName, statsByteSize); try { stats = new StatsAccessor(confinedArena.allocate(statsByteSize, CACHELINE_SIZE)); - totalAllocated = statsByteSize; - nameBufBase = confinedArena.allocate(NAMEBUF_SIZE).address(); } catch (OutOfMemoryError e) { System.err.print(diagnosticString); - System.err.println(totalAllocated); throw e; } processChunk(); @@ -155,227 +137,110 @@ public class CalculateAverage_mtopolnik { } } - private static final int MAX_TEMPERATURE_LEN = 5; - private static final int MAX_ROW_LEN = MAX_NAME_LEN + 1 + MAX_TEMPERATURE_LEN + 1; - private static final long DANGER_ZONE_LENGTH = ((MAX_ROW_LEN - 1) / 8 * 8 + 8); - private void processChunk() { + final long inputSize = this.inputSize; + final long inputBase = this.inputBase; + long cursor = 0; + long lastNameWord; while (cursor < inputSize) { - boolean withinSafeZone; - long word1; - long word2; - long nameLen; long nameStartAddress = inputBase + cursor; - if (cursor + DANGER_ZONE_LENGTH <= inputSize) { - withinSafeZone = true; - word1 = UNSAFE.getLong(nameStartAddress); - word2 = UNSAFE.getLong(nameStartAddress + Long.BYTES); - nameLen = nameLen(word1, word2, withinSafeZone); - word1 = maskWord(word1, nameLen); - word2 = maskWord(word2, nameLen - Long.BYTES); + long nameWord0 = UNSAFE.getLong(nameStartAddress); + long nameWord1 = 0; + long matchBits = semicolonMatchBits(nameWord0); + long hash; + int nameLen; + int temperature; + if (matchBits != 0) { + nameLen = nameLen(matchBits); + nameWord0 = maskWord(nameWord0, matchBits); + cursor += nameLen; + long tempWord = UNSAFE.getLong(inputBase + cursor); + int dotPos = dotPos(tempWord); + temperature = parseTemperature(tempWord, dotPos); + cursor += (dotPos >> 3) + 3; + hash = hash(nameWord0); + if (stats.gotoName0(hash, nameWord0)) { + stats.observe(temperature); + continue; + } + lastNameWord = nameWord0; } - else { - withinSafeZone = false; - UNSAFE.putLong(nameBufBase, 0); - UNSAFE.putLong(nameBufBase + Long.BYTES, 0); - UNSAFE.copyMemory(nameStartAddress, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor)); - word1 = UNSAFE.getLong(nameBufBase); - word2 = UNSAFE.getLong(nameBufBase + Long.BYTES); - nameLen = nameLen(word1, word2, withinSafeZone); + else { // nameLen > 8 + hash = hash(nameWord0); + nameWord1 = UNSAFE.getLong(nameStartAddress + Long.BYTES); + matchBits = semicolonMatchBits(nameWord1); + if (matchBits != 0) { + nameLen = Long.BYTES + nameLen(matchBits); + nameWord1 = maskWord(nameWord1, matchBits); + cursor += nameLen; + long tempWord = UNSAFE.getLong(inputBase + cursor); + int dotPos = dotPos(tempWord); + temperature = parseTemperature(tempWord, dotPos); + cursor += (dotPos >> 3) + 3; + if (stats.gotoName1(hash, nameWord0, nameWord1)) { + stats.observe(temperature); + continue; + } + lastNameWord = nameWord1; + } + else { // nameLen > 16 + nameLen = 2 * Long.BYTES; + while (true) { + lastNameWord = UNSAFE.getLong(nameStartAddress + nameLen); + matchBits = semicolonMatchBits(lastNameWord); + if (matchBits != 0) { + nameLen += nameLen(matchBits); + lastNameWord = maskWord(lastNameWord, matchBits); + cursor += nameLen; + long tempWord = UNSAFE.getLong(inputBase + cursor); + int dotPos = dotPos(tempWord); + temperature = parseTemperature(tempWord, dotPos); + cursor += (dotPos >> 3) + 3; + break; + } + nameLen += Long.BYTES; + } + } } - long hash = hash(word1); - assert nameLen > 0 && nameLen <= 100 : nameLen; - long tempStartAddress = nameStartAddress + nameLen + 1; - int temperature = withinSafeZone - ? parseTemperatureSwarAndAdvanceCursor(tempStartAddress) - : parseTemperatureSimpleAndAdvanceCursor(tempStartAddress); - updateStats(hash, nameStartAddress, nameLen, word1, word2, temperature, withinSafeZone); + stats.gotoAndObserve(hash, nameStartAddress, nameLen, nameWord0, nameWord1, lastNameWord, temperature); } } - private void updateStats( - long hash, long nameStartAddress, long nameLen, long nameWord1, long nameWord2, - int temperature, boolean withinSafeZone) { - int tableIndex = (int) (hash & TABLE_INDEX_MASK); - while (true) { - stats.gotoIndex(tableIndex); - if (stats.hash() == hash && stats.nameLen() == nameLen && nameEquals( - stats.nameAddress(), nameStartAddress, nameLen, nameWord1, nameWord2, withinSafeZone)) { - stats.setSum(stats.sum() + temperature); - stats.setCount(stats.count() + 1); - stats.setMin((short) Integer.min(stats.min(), temperature)); - stats.setMax((short) Integer.max(stats.max(), temperature)); - return; - } - if (stats.nameLen() != 0) { - tableIndex = (tableIndex + 1) & TABLE_INDEX_MASK; - continue; - } - stats.setHash(hash); - stats.setNameLen((int) nameLen); - stats.setSum(temperature); - stats.setCount(1); - stats.setMin((short) temperature); - stats.setMax((short) temperature); - UNSAFE.copyMemory(nameStartAddress, stats.nameAddress(), nameLen); - return; - } + private static final long BROADCAST_SEMICOLON = 0x3B3B3B3B3B3B3B3BL; + private static final long BROADCAST_0x01 = 0x0101010101010101L; + private static final long BROADCAST_0x80 = 0x8080808080808080L; + + private static long semicolonMatchBits(long word) { + long diff = word ^ BROADCAST_SEMICOLON; + return (diff - BROADCAST_0x01) & (~diff & BROADCAST_0x80); } - // Credit: merykitty - private int parseTemperatureSwarAndAdvanceCursor(long tempStartAddress) { - long word = UNSAFE.getLong(tempStartAddress); - final long negated = ~word; - final int dotPos = Long.numberOfTrailingZeros(negated & 0x10101000); - cursor = (tempStartAddress + (dotPos / 8) + 3) - inputBase; - final long signed = (negated << 59) >> 63; + // credit: artsiomkorzun + private static long maskWord(long word, long matchBits) { + long mask = matchBits ^ (matchBits - 1); + return word & mask; + } + + // credit: merykitty + private static int dotPos(long word) { + return Long.numberOfTrailingZeros(~word & 0x10101000); + } + + // credit: merykitty + private static int parseTemperature(long word, int dotPos) { + final long signed = (~word << 59) >> 63; final long removeSignMask = ~(signed & 0xFF); final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L; final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; return (int) ((absValue ^ signed) - signed); } - private int parseTemperatureSimpleAndAdvanceCursor(long tempStartAddress) { - final byte minus = (byte) '-'; - final byte zero = (byte) '0'; - final byte dot = (byte) '.'; - - byte ch = UNSAFE.getByte(tempStartAddress); - long address = tempStartAddress; - int temperature; - int sign; - if (ch == minus) { - sign = -1; - address++; - ch = UNSAFE.getByte(address); - } - else { - sign = 1; - } - temperature = ch - zero; - address++; - ch = UNSAFE.getByte(address); - if (ch == dot) { - address++; - ch = UNSAFE.getByte(address); - } - else { - temperature = 10 * temperature + (ch - zero); - address += 2; - ch = UNSAFE.getByte(address); - } - temperature = 10 * temperature + (ch - zero); - // address - inputBase is the length of the temperature field. - // A newline character follows the temperature, and so we advance - // the cursor past the newline to the start of the next line. - cursor = (address + 2) - inputBase; - return sign * temperature; + private static int nameLen(long separator) { + return (Long.numberOfTrailingZeros(separator) >>> 3) + 1; } - private static long hash(long word1) { - long seed = 0x51_7c_c1_b7_27_22_0a_95L; - int rotDist = 17; - - long hash = word1; - hash *= seed; - hash = Long.rotateLeft(hash, rotDist); - // hash ^= word2; - // hash *= seed; - // hash = Long.rotateLeft(hash, rotDist); - return hash; - } - - private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2, - boolean withinSafeZone) { - boolean mismatch1 = maskWord(inputWord1, len) != UNSAFE.getLong(statsAddr); - boolean mismatch2 = maskWord(inputWord2, len - Long.BYTES) != UNSAFE.getLong(statsAddr + Long.BYTES); - if (len <= 2 * Long.BYTES) { - return !(mismatch1 | mismatch2); - } - if (withinSafeZone) { - int i = 2 * Long.BYTES; - for (; i <= len - Long.BYTES; i += Long.BYTES) { - if (UNSAFE.getLong(inputAddr + i) != UNSAFE.getLong(statsAddr + i)) { - return false; - } - } - return maskWord(UNSAFE.getLong(inputAddr + i), len - i) == UNSAFE.getLong(statsAddr + i); - } - else { - for (int i = 2 * Long.BYTES; i < len; i++) { - if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) { - return false; - } - } - } - return true; - } - - private static long maskWord(long word, long len) { - long halfShiftDistance = Long.max(0, Long.BYTES - len) << 2; - long mask = (~0L >>> halfShiftDistance) >>> halfShiftDistance; // avoid Java trap of shiftDist % 64 - return word & mask; - } - - private static final long BROADCAST_0x01 = broadcastByte(0x01); - private static final long BROADCAST_0x80 = broadcastByte(0x80); - - // Adapted from https://jameshfisher.com/2017/01/24/bitwise-check-for-zero-byte/ - // and https://github.com/ashvardanian/StringZilla/blob/14e7a78edcc16b031c06b375aac1f66d8f19d45a/stringzilla/stringzilla.h#L139-L169 - long nameLen(long word1, long word2, boolean withinSafeZone) { - { - long matchBits1 = matchBits(word1); - long matchBits2 = matchBits(word2); - if ((matchBits1 | matchBits2) != 0) { - int trailing1 = Long.numberOfTrailingZeros(matchBits1); - int match1IsNonZero = trailing1 & 63; - match1IsNonZero |= match1IsNonZero >>> 3; - match1IsNonZero |= match1IsNonZero >>> 1; - match1IsNonZero |= match1IsNonZero >>> 1; - // Now match1IsNonZero is 1 if it's non-zero, else 0. Use it to - // raise the lowest bit in trailing2 if trailing1 is nonzero. This forces - // trailing2 to be zero if trailing1 is non-zero. - int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63; - // trailing1 | trailing2 works like trailing1 + trailing2 because if trailing2 is non-zero, - // then trailing1 is 64, and since trailing2 is < 64, there's no bit overlap. - return (trailing1 | trailing2) >> 3; - } - } - long nameStartAddress = inputBase + cursor; - long address = nameStartAddress + 2 * Long.BYTES; - long limit = inputBase + inputSize; - if (withinSafeZone) { - for (; address < limit; address += Long.BYTES) { - var block = maskWord(UNSAFE.getLong(address), limit - address); - long matchBits = matchBits(block); - if (matchBits != 0) { - return address + (Long.numberOfTrailingZeros(matchBits) >> 3) - nameStartAddress; - } - } - throw new RuntimeException("Semicolon not found"); - } - return addrOfSemicolonSafe(address, limit) - nameStartAddress; - } - - private static long addrOfSemicolonSafe(long address, long limit) { - for (; address < limit - Long.BYTES + 1; address += Long.BYTES) { - var block = UNSAFE.getLong(address); - long matchBits = matchBits(block); - if (matchBits != 0) { - return address + (Long.numberOfTrailingZeros(matchBits) >> 3); - } - } - for (; address < limit; address++) { - if (UNSAFE.getByte(address) == SEMICOLON) { - return address; - } - } - throw new RuntimeException("Semicolon not found"); - } - - private static long matchBits(long word) { - long diff = word ^ BROADCAST_SEMICOLON; - return (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; + private static long hash(long word) { + return Long.rotateLeft(word * 0x51_7c_c1_b7_27_22_0a_95L, 17); } // Copies the results from native memory to Java heap and puts them into the results array. @@ -403,22 +268,6 @@ public class CalculateAverage_mtopolnik { Arrays.sort(exported); results[myIndex] = exported; } - - private final ByteBuffer buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder()); - - private String longToString(long word) { - buf.clear(); - buf.putLong(word); - return new String(buf.array(), StandardCharsets.UTF_8); // + "|" + Arrays.toString(buf.array()); - } - } - - private static long broadcastByte(int b) { - long nnnnnnnn = b; - nnnnnnnn |= nnnnnnnn << 8; - nnnnnnnn |= nnnnnnnn << 16; - nnnnnnnn |= nnnnnnnn << 32; - return nnnnnnnn; } static class StatsAccessor { @@ -446,6 +295,16 @@ public class CalculateAverage_mtopolnik { slotBase = address + index * SIZEOF; } + private boolean gotoName0(long hash, long nameWord0) { + gotoIndex((int) (hash & TABLE_INDEX_MASK)); + return hash() == hash && nameWord0() == nameWord0; + } + + private boolean gotoName1(long hash, long nameWord0, long nameWord1) { + gotoIndex((int) (hash & TABLE_INDEX_MASK)); + return hash() == hash && nameWord0() == nameWord0 && nameWord1() == nameWord1; + } + long hash() { return UNSAFE.getLong(slotBase + HASH_OFFSET); } @@ -474,9 +333,17 @@ public class CalculateAverage_mtopolnik { return slotBase + NAME_OFFSET; } + long nameWord0() { + return UNSAFE.getLong(nameAddress()); + } + + long nameWord1() { + return UNSAFE.getLong(nameAddress() + Long.BYTES); + } + String exportNameString() { - final var bytes = new byte[nameLen()]; - UNSAFE.copyMemory(null, nameAddress(), bytes, ARRAY_BASE_OFFSET, nameLen()); + final var bytes = new byte[nameLen() - 1]; + UNSAFE.copyMemory(null, nameAddress(), bytes, ARRAY_BASE_OFFSET, bytes.length); return new String(bytes, StandardCharsets.UTF_8); } @@ -503,6 +370,59 @@ public class CalculateAverage_mtopolnik { void setMax(short max) { UNSAFE.putShort(slotBase + MAX_OFFSET, max); } + + void gotoAndObserve( + long hash, long nameStartAddress, int nameLen, long nameWord0, long nameWord1, long lastNameWord, + int temperature) { + int tableIndex = (int) (hash & TABLE_INDEX_MASK); + while (true) { + gotoIndex(tableIndex); + if (hash() == hash && nameLen() == nameLen && nameEquals( + nameAddress(), nameStartAddress, nameLen, nameWord0, nameWord1, lastNameWord)) { + observe(temperature); + break; + } + if (nameLen() != 0) { + tableIndex = (tableIndex + 1) & TABLE_INDEX_MASK; + continue; + } + initialize(hash, nameLen, nameStartAddress, temperature); + break; + } + } + + void initialize(long hash, long nameLen, long nameStartAddress, int temperature) { + setHash(hash); + setNameLen((int) nameLen); + setSum(temperature); + setCount(1); + setMin((short) temperature); + setMax((short) temperature); + UNSAFE.copyMemory(nameStartAddress, nameAddress(), nameLen); + } + + void observe(int temperature) { + setSum(sum() + temperature); + setCount(count() + 1); + setMin((short) Integer.min(min(), temperature)); + setMax((short) Integer.max(max(), temperature)); + } + + private static boolean nameEquals( + long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2, long lastInputWord) { + boolean mismatch1 = inputWord1 != UNSAFE.getLong(statsAddr); + boolean mismatch2 = inputWord2 != UNSAFE.getLong(statsAddr + Long.BYTES); + if (len <= 2 * Long.BYTES) { + return !(mismatch1 | mismatch2); + } + int i = 2 * Long.BYTES; + for (; i <= len - Long.BYTES; i += Long.BYTES) { + if (UNSAFE.getLong(inputAddr + i) != UNSAFE.getLong(statsAddr + i)) { + return false; + } + } + return i == len || lastInputWord == UNSAFE.getLong(statsAddr + i); + } } private static void mergeSortAndPrint(StationStats[][] results) { @@ -556,4 +476,34 @@ public class CalculateAverage_mtopolnik { } System.out.println('}'); } + + static class StationStats implements Comparable { + String name; + long sum; + int count; + int min; + int max; + + @Override + public String toString() { + return String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum / count) / 10.0, max / 10.0); + } + + @Override + public boolean equals(Object that) { + return that.getClass() == StationStats.class && ((StationStats) that).name.equals(this.name); + } + + @Override + public int compareTo(StationStats that) { + return name.compareTo(that.name); + } + } + + private static String longToString(long word) { + final ByteBuffer buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder()); + buf.clear(); + buf.putLong(word); + return new String(buf.array(), StandardCharsets.UTF_8); + } }