From d9ab36a241e4f38e404b5fd5f92de86337dd459c Mon Sep 17 00:00:00 2001 From: Jaromir Hamala Date: Sun, 28 Jan 2024 11:34:28 +0100 Subject: [PATCH] jerrinot's improvement (#607) * some random changes with minimal, if any, effect * use munmap() trick credit: thomaswue * some smaller tweaks * use native image --- calculate_average_jerrinot.sh | 10 +- prepare_jerrinot.sh | 9 +- .../onebrc/CalculateAverage_jerrinot.java | 262 +++++++++++------- 3 files changed, 175 insertions(+), 106 deletions(-) diff --git a/calculate_average_jerrinot.sh b/calculate_average_jerrinot.sh index 8de06c3..7311723 100755 --- a/calculate_average_jerrinot.sh +++ b/calculate_average_jerrinot.sh @@ -17,5 +17,11 @@ # -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_mtopolnik::recordMeasurementAndAdvanceCursor" # -XX:InlineSmallCode=10000 -XX:-TieredCompilation -XX:CICompilerCount=2 -XX:CompileThreshold=1000\ -java -XX:+UseParallelGC --enable-preview \ - --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jerrinot +if [ -f target/CalculateAverage_jerrinot_image ]; then + echo "Picking up existing native image 'target/CalculateAverage_jerrinot_image', delete the file to select JVM mode." 1>&2 + target/CalculateAverage_jerrinot_image +else + JAVA_OPTS="--enable-preview" + echo "Choosing to run the app in JVM mode as no native image was found, use prepare_jerrinot.sh to generate." 1>&2 + java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jerrinot +fi diff --git a/prepare_jerrinot.sh b/prepare_jerrinot.sh index f83a3ff..c36cae3 100755 --- a/prepare_jerrinot.sh +++ b/prepare_jerrinot.sh @@ -16,4 +16,11 @@ # source "$HOME/.sdkman/bin/sdkman-init.sh" -sdk use java 21.0.1-graal 1>&2 +sdk use java 21.0.2-graal 1>&2 + +# ./mvnw clean verify removes target/ and will re-trigger native image creation. +if [ ! -f target/CalculateAverage_jerrinot_image ]; then + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_jerrinot" + # Use -H:MethodFilter=CalculateAverage_jerrinot.* -H:Dump=:2 -H:PrintGraph=Network for IdealGraphVisualizer graph dumping. + native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_jerrinot_image dev.morling.onebrc.CalculateAverage_jerrinot +fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java index 2492c0f..36e3182 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java @@ -18,6 +18,7 @@ package dev.morling.onebrc; import sun.misc.Unsafe; import java.io.File; +import java.io.IOException; import java.io.RandomAccessFile; import java.lang.foreign.Arena; import java.lang.reflect.Field; @@ -54,9 +55,29 @@ public class CalculateAverage_jerrinot { } public static void main(String[] args) throws Exception { + // credits for spawning new workers: thomaswue + if (args.length == 0 || !("--worker".equals(args[0]))) { + spawnWorker(); + return; + } calculate(); } + private static void spawnWorker() throws IOException { + ProcessHandle.Info info = ProcessHandle.current().info(); + ArrayList workerCommand = new ArrayList<>(); + info.command().ifPresent(workerCommand::add); + info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args))); + workerCommand.add("--worker"); + new ProcessBuilder() + .command(workerCommand) + .inheritIO() + .redirectOutput(ProcessBuilder.Redirect.PIPE) + .start() + .getInputStream() + .transferTo(System.out); + } + static void calculate() throws Exception { final File file = new File(MEASUREMENTS_TXT); final long length = file.length(); @@ -140,6 +161,7 @@ public class CalculateAverage_jerrinot { } sb.append('}'); System.out.println(sb); + System.out.close(); } public static int ceilPow2(int i) { @@ -187,7 +209,7 @@ public class CalculateAverage_jerrinot { private static final int SLOW_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * SLOW_MAP_ENTRY_SIZE_BYTES; private static final int FAST_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * FAST_MAP_ENTRY_SIZE_BYTES; private static final int SLOW_MAP_MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES; - private static final long MAP_MASK = MAPS_SLOT_COUNT - 1; + private static final int MAP_MASK = MAPS_SLOT_COUNT - 1; private long slowMap; private long slowMapNamesPtr; @@ -281,9 +303,9 @@ public class CalculateAverage_jerrinot { doOne(cursorC, endC); transferToHeap(); - UNSAFE.freeMemory(fastMap); - UNSAFE.freeMemory(slowMap); - UNSAFE.freeMemory(slowMapNamesLo); + // UNSAFE.freeMemory(fastMap); + // UNSAFE.freeMemory(slowMap); + // UNSAFE.freeMemory(slowMapNamesLo); } private void transferToHeap() { @@ -339,11 +361,11 @@ public class CalculateAverage_jerrinot { long mask = getDelimiterMask(currentWord); long firstWordMask = ((mask - 1) ^ mask) >>> 8; final long isMaskZeroA = ((mask | -mask) >>> 63) ^ 1; - long ext = -isMaskZeroA & 0xFF00_0000_0000_0000L; + long ext = -isMaskZeroA; firstWordMask |= ext; long maskedFirstWord = currentWord & firstWordMask; - long hash = hash(maskedFirstWord); + int hash = hash(maskedFirstWord); while (mask == 0) { cursor += 8; currentWord = UNSAFE.getLong(cursor); @@ -353,22 +375,22 @@ public class CalculateAverage_jerrinot { final long semicolon = cursor + (delimiterByte >> 3); final long maskedWord = currentWord & ((mask - 1) ^ mask) >>> 8; - long len = semicolon - start; - long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, (int) hash, maskedWord); + int len = (int) (semicolon - start); + long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, hash, maskedWord); long temperatureWord = UNSAFE.getLong(semicolon + 1); cursor = parseAndStoreTemperature(semicolon + 1, baseEntryPtr, temperatureWord); } } - private static long hash(long word1) { + private static int hash(long word) { // credit: mtopolnik long seed = 0x51_7c_c1_b7_27_22_0a_95L; int rotDist = 17; - - long hash = word1; + // + long hash = word; hash *= seed; hash = Long.rotateLeft(hash, rotDist); - return hash; + return (int) hash; } @Override @@ -382,69 +404,87 @@ public class CalculateAverage_jerrinot { UNSAFE.setMemory(slowMapNamesPtr, SLOW_MAP_MAP_NAMES_BYTES, (byte) 0); while (cursorA < endA && cursorB < endB && cursorC < endC) { + long currentWordA = UNSAFE.getLong(cursorA); + long currentWordB = UNSAFE.getLong(cursorB); + long currentWordC = UNSAFE.getLong(cursorC); + long startA = cursorA; long startB = cursorB; long startC = cursorC; - long currentWordA = UNSAFE.getLong(startA); - long currentWordB = UNSAFE.getLong(startB); - long currentWordC = UNSAFE.getLong(startC); - long maskA = getDelimiterMask(currentWordA); long maskB = getDelimiterMask(currentWordB); long maskC = getDelimiterMask(currentWordC); - long firstWordMaskA = (maskA ^ (maskA - 1)) >>> 8; - long firstWordMaskB = (maskB ^ (maskB - 1)) >>> 8; - long firstWordMaskC = (maskC ^ (maskC - 1)) >>> 8; + long maskComplementA = -maskA; + long maskComplementB = -maskB; + long maskComplementC = -maskC; - final long isMaskZeroA = ((maskA | -maskA) >>> 63) ^ 1; - final long isMaskZeroB = ((maskB | -maskB) >>> 63) ^ 1; - final long isMaskZeroC = ((maskC | -maskC) >>> 63) ^ 1; + long maskWithDelimiterA = (maskA ^ (maskA - 1)); + long maskWithDelimiterB = (maskB ^ (maskB - 1)); + long maskWithDelimiterC = (maskC ^ (maskC - 1)); - long extA = -isMaskZeroA & 0xFF00_0000_0000_0000L; - long extB = -isMaskZeroB & 0xFF00_0000_0000_0000L; - long extC = -isMaskZeroC & 0xFF00_0000_0000_0000L; + long isMaskZeroA = (((maskA | maskComplementA) >>> 63) ^ 1); + long isMaskZeroB = (((maskB | maskComplementB) >>> 63) ^ 1); + long isMaskZeroC = (((maskC | maskComplementC) >>> 63) ^ 1); - firstWordMaskA |= extA; - firstWordMaskB |= extB; - firstWordMaskC |= extC; + cursorA += isMaskZeroA << 3; + cursorB += isMaskZeroB << 3; + cursorC += isMaskZeroC << 3; - long maskedFirstWordA = currentWordA & firstWordMaskA; - long maskedFirstWordB = currentWordB & firstWordMaskB; - long maskedFirstWordC = currentWordC & firstWordMaskC; + long nextWordA = UNSAFE.getLong(cursorA); + long nextWordB = UNSAFE.getLong(cursorB); + long nextWordC = UNSAFE.getLong(cursorC); - // assertMasks(isMaskZeroA, maskA); + long firstWordMaskA = maskWithDelimiterA >>> 8; + long firstWordMaskB = maskWithDelimiterB >>> 8; + long firstWordMaskC = maskWithDelimiterC >>> 8; - long hashA = hash(maskedFirstWordA); - long hashB = hash(maskedFirstWordB); - long hashC = hash(maskedFirstWordC); + long nextMaskA = getDelimiterMask(nextWordA); + long nextMaskB = getDelimiterMask(nextWordB); + long nextMaskC = getDelimiterMask(nextWordC); - cursorA += isMaskZeroA * 8; - cursorB += isMaskZeroB * 8; - cursorC += isMaskZeroC * 8; + boolean slowA = nextMaskA == 0; + boolean slowB = nextMaskB == 0; + boolean slowC = nextMaskC == 0; + boolean slowSome = (slowA || slowB || slowC); - currentWordA = UNSAFE.getLong(cursorA); - currentWordB = UNSAFE.getLong(cursorB); - currentWordC = UNSAFE.getLong(cursorC); + long extA = -isMaskZeroA; + long extB = -isMaskZeroB; + long extC = -isMaskZeroC; - maskA = getDelimiterMask(currentWordA); - while (maskA == 0) { - cursorA += 8; - currentWordA = UNSAFE.getLong(cursorA); - maskA = getDelimiterMask(currentWordA); - } - maskB = getDelimiterMask(currentWordB); - while (maskB == 0) { - cursorB += 8; - currentWordB = UNSAFE.getLong(cursorB); - maskB = getDelimiterMask(currentWordB); - } - maskC = getDelimiterMask(currentWordC); - while (maskC == 0) { - cursorC += 8; - currentWordC = UNSAFE.getLong(cursorC); - maskC = getDelimiterMask(currentWordC); + long maskedFirstWordA = (extA | firstWordMaskA) & currentWordA; + long maskedFirstWordB = (extB | firstWordMaskB) & currentWordB; + long maskedFirstWordC = (extC | firstWordMaskC) & currentWordC; + + int hashA = hash(maskedFirstWordA); + int hashB = hash(maskedFirstWordB); + int hashC = hash(maskedFirstWordC); + + currentWordA = nextWordA; + currentWordB = nextWordB; + currentWordC = nextWordC; + + maskA = nextMaskA; + maskB = nextMaskB; + maskC = nextMaskC; + if (slowSome) { + while (maskA == 0) { + cursorA += 8; + currentWordA = UNSAFE.getLong(cursorA); + maskA = getDelimiterMask(currentWordA); + } + + while (maskB == 0) { + cursorB += 8; + currentWordB = UNSAFE.getLong(cursorB); + maskB = getDelimiterMask(currentWordB); + } + while (maskC == 0) { + cursorC += 8; + currentWordC = UNSAFE.getLong(cursorC); + maskC = getDelimiterMask(currentWordC); + } } final int delimiterByteA = Long.numberOfTrailingZeros(maskA); @@ -458,40 +498,57 @@ public class CalculateAverage_jerrinot { long digitStartA = semicolonA + 1; long digitStartB = semicolonB + 1; long digitStartC = semicolonC + 1; + long temperatureWordA = UNSAFE.getLong(digitStartA); long temperatureWordB = UNSAFE.getLong(digitStartB); long temperatureWordC = UNSAFE.getLong(digitStartC); - final long maskedWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8; - final long maskedWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8; - final long maskedWordC = currentWordC & ((maskC - 1) ^ maskC) >>> 8; + long lastWordMaskA = ((maskA - 1) ^ maskA) >>> 8; + long lastWordMaskB = ((maskB - 1) ^ maskB) >>> 8; + long lastWordMaskC = ((maskC - 1) ^ maskC) >>> 8; - long lenA = semicolonA - startA; - long lenB = semicolonB - startB; - long lenC = semicolonC - startC; + final long maskedLastWordA = currentWordA & lastWordMaskA; + final long maskedLastWordB = currentWordB & lastWordMaskB; + final long maskedLastWordC = currentWordC & lastWordMaskC; + + int lenA = (int) (semicolonA - startA); + int lenB = (int) (semicolonB - startB); + int lenC = (int) (semicolonC - startC); + + int mapIndexA = hashA & MAP_MASK; + int mapIndexB = hashB & MAP_MASK; + int mapIndexC = hashC & MAP_MASK; long baseEntryPtrA; - if (lenA > 15) { - baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, (int) hashA, maskedWordA); - } - else { - baseEntryPtrA = getOrCreateEntryBaseOffsetFast(lenA, (int) hashA, maskedWordA, maskedFirstWordA); - } - long baseEntryPtrB; - if (lenB > 15) { - baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, (int) hashB, maskedWordB); - } - else { - baseEntryPtrB = getOrCreateEntryBaseOffsetFast(lenB, (int) hashB, maskedWordB, maskedFirstWordB); - } - long baseEntryPtrC; - if (lenC > 15) { - baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, (int) hashC, maskedWordC); + + if (slowSome) { + if (slowA) { + baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, hashA, maskedLastWordA); + } + else { + baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA); + } + + if (slowB) { + baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, hashB, maskedLastWordB); + } + else { + baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB); + } + + if (slowC) { + baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, hashC, maskedLastWordC); + } + else { + baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC); + } } else { - baseEntryPtrC = getOrCreateEntryBaseOffsetFast(lenC, (int) hashC, maskedWordC, maskedFirstWordC); + baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA); + baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB); + baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC); } cursorA = parseAndStoreTemperature(digitStartA, baseEntryPtrA, temperatureWordA); @@ -502,36 +559,35 @@ public class CalculateAverage_jerrinot { // System.out.println("Longest chain: " + longestChain); } - private long getOrCreateEntryBaseOffsetFast(long lenLong, int hash, long maskedLastWord, long maskedFirstWord) { - int lenA = (int) lenLong; - long mapIndexA = hash & MAP_MASK; + private long getOrCreateEntryBaseOffsetFast(int mapIndexA, int lenA, long maskedLastWord, long maskedFirstWord) { for (;;) { long basePtr = mapIndexA * FAST_MAP_ENTRY_SIZE_BYTES + fastMap; + long namePart1 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART1); + long namePart2 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART2); + if (namePart1 == maskedFirstWord && namePart2 == maskedLastWord) { + return basePtr; + } long lenPtr = basePtr + MAP_LEN_OFFSET; int len = UNSAFE.getInt(lenPtr); - if (len == lenA) { - long namePart1 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART1); - long namePart2 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART2); - if (namePart1 == maskedFirstWord && namePart2 == maskedLastWord) { - return basePtr; - } - } - else if (len == 0) { - UNSAFE.putInt(lenPtr, lenA); - // todo: this could be a single putLong() - UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE); - UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE); - UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART1, maskedFirstWord); - UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART2, maskedLastWord); - return basePtr; + if (len == 0) { + return newEntryFast(lenA, maskedLastWord, maskedFirstWord, lenPtr, basePtr); } mapIndexA = ++mapIndexA & MAP_MASK; } } - private long getOrCreateEntryBaseOffsetSlow(long lenLong, long startPtr, int hash, long maskedLastWord) { - long fullLen = lenLong & ~7L; - int lenA = (int) lenLong; + private static long newEntryFast(int lenA, long maskedLastWord, long maskedFirstWord, long lenPtr, long basePtr) { + UNSAFE.putInt(lenPtr, lenA); + // todo: this could be a single putLong() + UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE); + UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE); + UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART1, maskedFirstWord); + UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART2, maskedLastWord); + return basePtr; + } + + private long getOrCreateEntryBaseOffsetSlow(int lenA, long startPtr, int hash, long maskedLastWord) { + long fullLen = lenA & ~7L; long mapIndexA = hash & MAP_MASK; for (;;) { long basePtr = mapIndexA * SLOW_MAP_ENTRY_SIZE_BYTES + slowMap; @@ -550,7 +606,7 @@ public class CalculateAverage_jerrinot { UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE); UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE); UNSAFE.copyMemory(startPtr, slowMapNamesPtr, lenA); - long alignedLen = (lenLong & ~7L) + 8; + long alignedLen = (lenA & ~7L) + 8; slowMapNamesPtr += alignedLen; return basePtr; }