From a5ce4ba77184d669e67d9633ee20c466ad167742 Mon Sep 17 00:00:00 2001 From: Thomas Wuerthinger Date: Wed, 31 Jan 2024 09:34:15 +0100 Subject: [PATCH] Added comments to used flags, clean up code, final fine tuning. (#674) --- prepare_thomaswue.sh | 16 +- .../onebrc/CalculateAverage_thomaswue.java | 234 ++++++++---------- 2 files changed, 114 insertions(+), 136 deletions(-) diff --git a/prepare_thomaswue.sh b/prepare_thomaswue.sh index da0a591..3e75233 100755 --- a/prepare_thomaswue.sh +++ b/prepare_thomaswue.sh @@ -20,7 +20,19 @@ sdk use java 21.0.2-graal 1>&2 # ./mvnw clean verify removes target/ and will re-trigger native image creation. if [ ! -f target/CalculateAverage_thomaswue_image ]; then - NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -H:TuneInlinerExploration=1 -march=native --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_thomaswue\$Scanner" - # Use -H:MethodFilter=CalculateAverage_thomaswue.* -H:Dump=:2 -H:PrintGraph=Network for IdealGraphVisualizer graph dumping. + + # Performance tuning flags, optimization level 3, maximum inlining exploration, and compile for the architecture where the native image is generated. + NATIVE_IMAGE_OPTS="-O3 -H:TuneInlinerExploration=1 -march=native" + + # Need to enable preview for accessing the raw address of the foreign memory access API. + # Initializing the Scanner to make sure the unsafe access object is known as a non-null compile time constant. + NATIVE_IMAGE_OPTS="$NATIVE_IMAGE_OPTS --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_thomaswue\$Scanner" + + # There is no need for garbage collection and therefore also no safepoints required. + NATIVE_IMAGE_OPTS="$NATIVE_IMAGE_OPTS --gc=epsilon -H:-GenLoopSafepoints" + + # Uncomment the following line for outputting the compiler graph to the IdealGraphVisualizer + # NATIVE_IMAGE_OPTS="$NATIVE_IMAGE_OPTS -H:MethodFilter=CalculateAverage_thomaswue.* -H:Dump=:2 -H:PrintGraph=Network" + native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_thomaswue_image dev.morling.onebrc.CalculateAverage_thomaswue fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java index 9b21f91..dc4df0c 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java @@ -27,9 +27,7 @@ import java.util.concurrent.atomic.AtomicLong; * split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread. * Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in * the end. - * - * Runs in 0.40s on an Intel i9-13900K. - * + * Runs in 0.39s on an Intel i9-13900K. * Credit: * Quan Anh Mai for branchless number parsing code * Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea @@ -103,49 +101,111 @@ public class CalculateAverage_thomaswue { return result; } - private static Result findResult(long initialWord, long initialPos, Scanner scanner, Result[] results, List collectedResults) { + private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart, List collectedResults) { + Result[] results = new Result[HASH_TABLE_SIZE]; + while (true) { + long current = counter.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE; + if (current >= fileEnd) { + return; + } + + long segmentEnd = nextNewLine(Math.min(fileEnd - 1, current + SEGMENT_SIZE)); + long segmentStart; + if (current == fileStart) { + segmentStart = current; + } + else { + segmentStart = nextNewLine(current) + 1; + } + + long dist = (segmentEnd - segmentStart) / 3; + long midPoint1 = nextNewLine(segmentStart + dist); + long midPoint2 = nextNewLine(segmentStart + dist + dist); + + Scanner scanner1 = new Scanner(segmentStart, midPoint1); + Scanner scanner2 = new Scanner(midPoint1 + 1, midPoint2); + Scanner scanner3 = new Scanner(midPoint2 + 1, segmentEnd); + while (true) { + if (!scanner1.hasNext()) { + break; + } + if (!scanner2.hasNext()) { + break; + } + if (!scanner3.hasNext()) { + break; + } + long word1 = scanner1.getLong(); + long word2 = scanner2.getLong(); + long word3 = scanner3.getLong(); + long delimiterMask1 = findDelimiter(word1); + long delimiterMask2 = findDelimiter(word2); + long delimiterMask3 = findDelimiter(word3); + Result existingResult1 = findResult(word1, delimiterMask1, scanner1, results, collectedResults); + Result existingResult2 = findResult(word2, delimiterMask2, scanner2, results, collectedResults); + Result existingResult3 = findResult(word3, delimiterMask3, scanner3, results, collectedResults); + long number1 = scanNumber(scanner1); + long number2 = scanNumber(scanner2); + long number3 = scanNumber(scanner3); + record(existingResult1, number1); + record(existingResult2, number2); + record(existingResult3, number3); + } + + while (scanner1.hasNext()) { + long word = scanner1.getLong(); + long pos = findDelimiter(word); + record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1)); + } + while (scanner2.hasNext()) { + long word = scanner2.getLong(); + long pos = findDelimiter(word); + record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2)); + } + while (scanner3.hasNext()) { + long word = scanner3.getLong(); + long pos = findDelimiter(word); + record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3)); + } + } + } + + private static Result findResult(long initialWord, long initialDelimiterMask, Scanner scanner, Result[] results, List collectedResults) { Result existingResult; long word = initialWord; - long pos = initialPos; + long delimiterMask = initialDelimiterMask; long hash; long nameAddress = scanner.pos(); // Search for ';', one long at a time. There are two common cases that a specially treated: // (b) the ';' is found in the first 16 bytes - if (pos != 0) { + if (delimiterMask != 0) { // Special case for when the ';' is found in the first 8 bytes. - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); + int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); + word = (word << (63 - trailingZeros)); + scanner.add(trailingZeros >>> 3); hash = word; - - int index = hashToIndex(hash, results); - existingResult = results[index]; - + existingResult = results[hashToIndex(hash, results)]; if (existingResult != null && existingResult.lastNameLong == word) { return existingResult; } - scanner.setPos(nameAddress + pos); } else { // Special case for when the ';' is found in bytes 9-16. - scanner.add(8); hash = word; long prevWord = word; + scanner.add(8); word = scanner.getLong(); - pos = findDelimiter(word); - if (pos != 0) { - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); + delimiterMask = findDelimiter(word); + if (delimiterMask != 0) { + int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); + word = (word << (63 - trailingZeros)); + scanner.add(trailingZeros >>> 3); hash ^= word; - int index = hashToIndex(hash, results); - existingResult = results[index]; - + existingResult = results[hashToIndex(hash, results)]; if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { return existingResult; } - scanner.setPos(nameAddress + pos + 8); } else { // Slow-path for when the ';' could not be found in the first 16 bytes. @@ -153,11 +213,11 @@ public class CalculateAverage_thomaswue { hash ^= word; while (true) { word = scanner.getLong(); - pos = findDelimiter(word); - if (pos != 0) { - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); + delimiterMask = findDelimiter(word); + if (delimiterMask != 0) { + int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); + word = (word << (63 - trailingZeros)); + scanner.add(trailingZeros >>> 3); hash ^= word; break; } @@ -204,7 +264,8 @@ public class CalculateAverage_thomaswue { private static long nextNewLine(long prev) { while (true) { long currentWord = Scanner.UNSAFE.getLong(prev); - long pos = findNewLine(currentWord); + long input = currentWord ^ 0x0A0A0A0A0A0A0A0AL; + long pos = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; if (pos != 0) { prev += Long.numberOfTrailingZeros(pos) >>> 3; break; @@ -216,87 +277,11 @@ public class CalculateAverage_thomaswue { return prev; } - // Main parse loop. - private static Result[] parseLoop(AtomicLong counter, long fileEnd, long fileStart, List collectedResults) { - Result[] results = new Result[HASH_TABLE_SIZE]; - - while (true) { - long current = counter.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE; - - if (current >= fileEnd) { - return results; - } - - long segmentEnd = nextNewLine(Math.min(fileEnd - 1, current + SEGMENT_SIZE)); - long segmentStart; - if (current == fileStart) { - segmentStart = current; - } - else { - segmentStart = nextNewLine(current) + 1; - } - - long dist = (segmentEnd - segmentStart) / 3; - long midPoint1 = nextNewLine(segmentStart + dist); - long midPoint2 = nextNewLine(segmentStart + dist + dist); - - Scanner scanner1 = new Scanner(segmentStart, midPoint1); - Scanner scanner2 = new Scanner(midPoint1 + 1, midPoint2); - Scanner scanner3 = new Scanner(midPoint2 + 1, segmentEnd); - while (true) { - if (!scanner1.hasNext()) { - break; - } - if (!scanner2.hasNext()) { - break; - } - if (!scanner3.hasNext()) { - break; - } - - long word1 = scanner1.getLong(); - long word2 = scanner2.getLong(); - long word3 = scanner3.getLong(); - long pos1 = findDelimiter(word1); - long pos2 = findDelimiter(word2); - long pos3 = findDelimiter(word3); - Result existingResult1 = findResult(word1, pos1, scanner1, results, collectedResults); - Result existingResult2 = findResult(word2, pos2, scanner2, results, collectedResults); - Result existingResult3 = findResult(word3, pos3, scanner3, results, collectedResults); - long number1 = scanNumber(scanner1); - long number2 = scanNumber(scanner2); - long number3 = scanNumber(scanner3); - record(existingResult1, number1); - record(existingResult2, number2); - record(existingResult3, number3); - } - - while (scanner1.hasNext()) { - long word = scanner1.getLong(); - long pos = findDelimiter(word); - record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1)); - } - - while (scanner2.hasNext()) { - long word = scanner2.getLong(); - long pos = findDelimiter(word); - record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2)); - } - - while (scanner3.hasNext()) { - long word = scanner3.getLong(); - long pos = findDelimiter(word); - record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3)); - } - } - } - private static long scanNumber(Scanner scanPtr) { - scanPtr.add(1); - long numberWord = scanPtr.getLong(); - int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); + long numberWord = scanPtr.getLongAt(scanPtr.pos() + 1); + int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000L); long number = convertIntoNumber(decimalSepPos, numberWord); - scanPtr.add((decimalSepPos >>> 3) + 3); + scanPtr.add((decimalSepPos >>> 3) + 4); return number; } @@ -316,10 +301,6 @@ public class CalculateAverage_thomaswue { return (int) (hashAsInt & (results.length - 1)); } - private static long mask(long word, long pos) { - return (word << ((7 - pos) << 3)); - } - // Special method to convert a number in the ascii number into an int without branches created by Quan Anh Mai. private static long convertIntoNumber(int decimalSepPos, long numberWord) { int shift = 28 - decimalSepPos; @@ -337,14 +318,7 @@ public class CalculateAverage_thomaswue { private static long findDelimiter(long word) { long input = word ^ 0x3B3B3B3B3B3B3B3BL; - long tmp = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; - return tmp; - } - - private static long findNewLine(long word) { - long input = word ^ 0x0A0A0A0A0A0A0A0AL; - long tmp = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; - return tmp; + return (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; } private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List collectedResults) { @@ -357,14 +331,13 @@ public class CalculateAverage_thomaswue { r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8); } int remainingShift = (64 - (nameLength + 1 - i) << 3); - long lastWord = (scanner.getLongAt(nameAddress + i) << remainingShift); - r.lastNameLong = lastWord; + r.lastNameLong = (scanner.getLongAt(nameAddress + i) << remainingShift); r.nameAddress = nameAddress; collectedResults.add(r); return r; } - private static class Result { + private static final class Result { long lastNameLong, secondLastNameLong; short min, max; int count; @@ -409,9 +382,10 @@ public class CalculateAverage_thomaswue { } } - private static class Scanner { + private static final class Scanner { private static final sun.misc.Unsafe UNSAFE = initUnsafe(); - private long pos, end; + private long pos; + private final long end; private static sun.misc.Unsafe initUnsafe() { try { @@ -452,13 +426,5 @@ public class CalculateAverage_thomaswue { byte getByteAt(long pos) { return UNSAFE.getByte(pos); } - - long getLongAt(long pos, long[] array) { - return UNSAFE.getLong(array, pos + sun.misc.Unsafe.ARRAY_LONG_BASE_OFFSET); - } - - void setPos(long l) { - this.pos = l; - } } } \ No newline at end of file