From 7e525c599236a24984b3b1f2b8fd7ec5dbb960bc Mon Sep 17 00:00:00 2001 From: Thomas Wuerthinger Date: Sun, 28 Jan 2024 17:59:57 +0100 Subject: [PATCH] Some fine tuning for thomaswue (#606) * Some fine tuning. * Process 2MB segments to make all threads finish at the same time. Process with 3 scanners in parallel in the same thread. --- prepare_thomaswue.sh | 2 +- .../onebrc/CalculateAverage_thomaswue.java | 405 +++++++++++------- 2 files changed, 247 insertions(+), 160 deletions(-) diff --git a/prepare_thomaswue.sh b/prepare_thomaswue.sh index 32616a9..10dc732 100755 --- a/prepare_thomaswue.sh +++ b/prepare_thomaswue.sh @@ -20,7 +20,7 @@ sdk use java 21.0.2-graal 1>&2 # ./mvnw clean verify removes target/ and will re-trigger native image creation. if [ ! -f target/CalculateAverage_thomaswue_image ]; then - NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_thomaswue\$Scanner" + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -H:-GenLoopSafepoints -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_thomaswue\$Scanner" # Use -H:MethodFilter=CalculateAverage_thomaswue.* -H:Dump=:2 -H:PrintGraph=Network for IdealGraphVisualizer graph dumping. native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_thomaswue_image dev.morling.onebrc.CalculateAverage_thomaswue fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java index 406c85d..c02a881 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java @@ -23,16 +23,17 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.*; +import java.util.concurrent.atomic.AtomicLong; import java.util.stream.IntStream; /** * Simple solution that memory maps the input file, then splits it into one segment per available core and uses * sun.misc.Unsafe to directly access the mapped memory. Uses a long at a time when checking for collision. *

- * Runs in 0.60s on my Intel i9-13900K + * Runs in 0.41s on my Intel i9-13900K * Perf stats: - * 34,716,719,245 cpu_core/cycles/ - * 40,776,530,892 cpu_atom/cycles/ + * 25,286,227,376 cpu_core/cycles/ + * 26,833,723,225 cpu_atom/cycles/ */ public class CalculateAverage_thomaswue { private static final String FILE = "./measurements.txt"; @@ -42,10 +43,11 @@ public class CalculateAverage_thomaswue { // Holding the current result for a single city. private static class Result { long lastNameLong, secondLastNameLong; - long[] name; - int count; - short min, max; + long min, max; long sum; + int count; + long[] name; + String nameAsString; private Result() { this.min = MAX_TEMP; @@ -73,36 +75,59 @@ public class CalculateAverage_thomaswue { } public String calcName() { - ByteBuffer bb = ByteBuffer.allocate(name.length * Long.BYTES).order(ByteOrder.nativeOrder()); - bb.asLongBuffer().put(name); - byte[] array = bb.array(); - int i = 0; - while (array[i++] != ';') - ; - return new String(array, 0, i - 1, StandardCharsets.UTF_8); + if (nameAsString == null) { + ByteBuffer bb = ByteBuffer.allocate(name.length * Long.BYTES).order(ByteOrder.nativeOrder()); + bb.asLongBuffer().put(name); + byte[] array = bb.array(); + int i = 0; + while (array[i++] != ';') + ; + nameAsString = new String(array, 0, i - 1, StandardCharsets.UTF_8); + } + return nameAsString; } } - public static void main(String[] args) throws IOException { + public static void main(String[] args) throws IOException, InterruptedException { if (args.length == 0 || !("--worker".equals(args[0]))) { spawnWorker(); return; } // Calculate input segments. - int numberOfChunks = Runtime.getRuntime().availableProcessors(); - long[] chunks = getSegments(numberOfChunks); + int numberOfWorkers = Runtime.getRuntime().availableProcessors(); + final AtomicLong cursor = new AtomicLong(); + final long fileEnd; + final long fileStart; + + try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { + long fileSize = fileChannel.size(); + fileStart = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, java.lang.foreign.Arena.global()).address(); + cursor.set(fileStart); + fileEnd = fileStart + fileSize; + } // Parallel processing of segments. - List> allResults = IntStream.range(0, chunks.length - 1).mapToObj(chunkIndex -> parseLoop(chunks[chunkIndex], chunks[chunkIndex + 1])) - .map(resultArray -> { - List results = new ArrayList<>(); - for (Result r : resultArray) { - if (r != null) { - results.add(r); - } + Thread[] threads = new Thread[numberOfWorkers]; + List[] allResults = new List[numberOfWorkers]; + for (int i = 0; i < threads.length; ++i) { + final int index = i; + threads[i] = new Thread(() -> { + Result[] resultArray = parseLoop(cursor, fileEnd, fileStart); + List results = new ArrayList<>(500); + for (Result r : resultArray) { + if (r != null) { + r.calcName(); + results.add(r); } - return results; - }).parallel().toList(); + } + allResults[index] = results; + }); + threads[i].start(); + } + + for (Thread thread : threads) { + thread.join(); + } // Final output. System.out.println(accumulateResults(allResults)); @@ -115,17 +140,12 @@ public class CalculateAverage_thomaswue { info.command().ifPresent(workerCommand::add); info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args))); workerCommand.add("--worker"); - new ProcessBuilder() - .command(workerCommand) - .inheritIO() - .redirectOutput(ProcessBuilder.Redirect.PIPE) - .start() - .getInputStream() - .transferTo(System.out); + new ProcessBuilder().command(workerCommand).inheritIO().redirectOutput(ProcessBuilder.Redirect.PIPE) + .start().getInputStream().transferTo(System.out); } // Accumulate results sequentially for simplicity. - private static TreeMap accumulateResults(List> allResults) { + private static TreeMap accumulateResults(List[] allResults) { TreeMap result = new TreeMap<>(); for (List resultArr : allResults) { for (Result r : resultArr) { @@ -139,141 +159,220 @@ public class CalculateAverage_thomaswue { return result; } - // Main parse loop. - private static Result[] parseLoop(long chunkStart, long chunkEnd) { - Result[] results = new Result[1 << 17]; - Scanner scanner = new Scanner(chunkStart, chunkEnd); - long word = scanner.getLong(); - long pos = findDelimiter(word); - while (scanner.hasNext()) { - long nameAddress = scanner.pos(); - long hash = 0; + private static Result findResult(long initialWord, long initialPos, Scanner scanner, Result[] results) { - // Search for ';', one long at a time. + Result existingResult; + long word = initialWord; + long pos = initialPos; + long hash; + long nameAddress = scanner.pos(); + + // Search for ';', one long at a time. + if (pos != 0) { + pos = Long.numberOfTrailingZeros(pos) >>> 3; + scanner.add(pos); + word = mask(word, pos); + hash = word; + + int index = hashToIndex(hash, results); + existingResult = results[index]; + + if (existingResult != null && existingResult.lastNameLong == word) { + return existingResult; + } + else { + scanner.setPos(nameAddress + pos); + } + } + else { + scanner.add(8); + hash = word; + long prevWord = word; + word = scanner.getLong(); + pos = findDelimiter(word); if (pos != 0) { pos = Long.numberOfTrailingZeros(pos) >>> 3; scanner.add(pos); word = mask(word, pos); - hash = word; + hash ^= word; + int index = hashToIndex(hash, results); + existingResult = results[index]; - int number = scanNumber(scanner); - long nextWord = scanner.getLong(); - long nextPos = findDelimiter(nextWord); - - Result existingResult = results[hashToIndex(hash, results)]; - if (existingResult != null && existingResult.lastNameLong == word) { - word = nextWord; - pos = nextPos; - record(existingResult, number); - continue; + if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { + return existingResult; + } + else { + scanner.setPos(nameAddress + pos + 8); } - - scanner.setPos(nameAddress + pos); } else { scanner.add(8); - hash = word; - long prevWord = word; - word = scanner.getLong(); - pos = findDelimiter(word); - if (pos != 0) { - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); - hash ^= word; - - Result existingResult = results[hashToIndex(hash, results)]; - if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { - int number = scanNumber(scanner); - word = scanner.getLong(); - pos = findDelimiter(word); - record(existingResult, number); - continue; + hash ^= word; + while (true) { + word = scanner.getLong(); + pos = findDelimiter(word); + if (pos != 0) { + pos = Long.numberOfTrailingZeros(pos) >>> 3; + scanner.add(pos); + word = mask(word, pos); + hash ^= word; + break; } - } - else { - scanner.add(8); - hash ^= word; - while (true) { - word = scanner.getLong(); - pos = findDelimiter(word); - if (pos != 0) { - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); - hash ^= word; - break; - } - else { - scanner.add(8); - hash ^= word; - } + else { + scanner.add(8); + hash ^= word; } } } - - // Save length of name for later. - int nameLength = (int) (scanner.pos() - nameAddress); - int number = scanNumber(scanner); - - // Final calculation for index into hash table. - int tableIndex = hashToIndex(hash, results); - outer: while (true) { - Result existingResult = results[tableIndex]; - if (existingResult == null) { - existingResult = newEntry(results, nameAddress, tableIndex, nameLength, scanner); - } - // Check for collision. - int i = 0; - int namePos = 0; - for (; i < nameLength + 1 - 8; i += 8) { - if (namePos >= existingResult.name.length || existingResult.name[namePos++] != scanner.getLongAt(nameAddress + i)) { - tableIndex = (tableIndex + 31) & (results.length - 1); - continue outer; - } - } - - int remainingShift = (64 - (nameLength + 1 - i) << 3); - if (((existingResult.lastNameLong ^ (scanner.getLongAt(nameAddress + i) << remainingShift)) == 0)) { - record(existingResult, number); - break; - } - else { - // Collision error, try next. - tableIndex = (tableIndex + 31) & (results.length - 1); - } - } - - word = scanner.getLong(); - pos = findDelimiter(word); } - return results; + + // Save length of name for later. + int nameLength = (int) (scanner.pos() - nameAddress); + + // Final calculation for index into hash table. + int tableIndex = hashToIndex(hash, results); + outer: while (true) { + existingResult = results[tableIndex]; + if (existingResult == null) { + existingResult = newEntry(results, nameAddress, tableIndex, nameLength, scanner); + } + // Check for collision. + int i = 0; + long[] name = existingResult.name; + for (; i < nameLength + 1 - 8; i += 8) { + if (scanner.getLongAt(i, name) != scanner.getLongAt(nameAddress + i)) { + tableIndex = (tableIndex + 31) & (results.length - 1); + continue outer; + } + } + + int remainingShift = (64 - (nameLength + 1 - i) << 3); + if (((existingResult.lastNameLong ^ (scanner.getLongAt(nameAddress + i) << remainingShift)) == 0)) { + break; + } + else { + // Collision error, try next. + tableIndex = (tableIndex + 31) & (results.length - 1); + } + } + return existingResult; } - private static int scanNumber(Scanner scanPtr) { + private static long nextNL(long prev) { + while (true) { + long currentWord = Scanner.UNSAFE.getLong(prev); + long pos = findNewLine(currentWord); + if (pos != 0) { + prev += Long.numberOfTrailingZeros(pos) >>> 3; + break; + } + else { + prev += 8; + } + } + return prev; + } + + private static final int SEGMENT_SIZE = 1024 * 1024 * 2; + + // Main parse loop. + private static Result[] parseLoop(AtomicLong counter, long fileEnd, long fileStart) { + Result[] results = new Result[1 << 17]; + + while (true) { + long current = counter.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE; + + if (current >= fileEnd) { + return results; + } + + long segmentEnd = nextNL(Math.min(fileEnd - 1, current + SEGMENT_SIZE)); + long segmentStart; + if (current == fileStart) { + segmentStart = current; + } + else { + segmentStart = nextNL(current) + 1; + } + + long dist = (segmentEnd - segmentStart) / 3; + long midPoint1 = nextNL(segmentStart + dist); + long midPoint2 = nextNL(segmentStart + dist + dist); + + Scanner scanner1 = new Scanner(segmentStart, midPoint1); + Scanner scanner2 = new Scanner(midPoint1 + 1, midPoint2); + Scanner scanner3 = new Scanner(midPoint2 + 1, segmentEnd); + while (true) { + if (!scanner1.hasNext()) { + break; + } + if (!scanner2.hasNext()) { + break; + } + if (!scanner3.hasNext()) { + break; + } + + long word1 = scanner1.getLong(); + long word2 = scanner2.getLong(); + long word3 = scanner3.getLong(); + long pos1 = findDelimiter(word1); + long pos2 = findDelimiter(word2); + long pos3 = findDelimiter(word3); + Result existingResult1 = findResult(word1, pos1, scanner1, results); + Result existingResult2 = findResult(word2, pos2, scanner2, results); + Result existingResult3 = findResult(word3, pos3, scanner3, results); + long number1 = scanNumber(scanner1); + long number2 = scanNumber(scanner2); + long number3 = scanNumber(scanner3); + record(existingResult1, number1); + record(existingResult2, number2); + record(existingResult3, number3); + } + + while (scanner1.hasNext()) { + long word = scanner1.getLong(); + long pos = findDelimiter(word); + record(findResult(word, pos, scanner1, results), scanNumber(scanner1)); + } + + while (scanner2.hasNext()) { + long word = scanner2.getLong(); + long pos = findDelimiter(word); + record(findResult(word, pos, scanner2, results), scanNumber(scanner2)); + } + + while (scanner3.hasNext()) { + long word = scanner3.getLong(); + long pos = findDelimiter(word); + record(findResult(word, pos, scanner3, results), scanNumber(scanner3)); + } + } + } + + private static long scanNumber(Scanner scanPtr) { scanPtr.add(1); long numberWord = scanPtr.getLong(); int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); - int number = convertIntoNumber(decimalSepPos, numberWord); + long number = convertIntoNumber(decimalSepPos, numberWord); scanPtr.add((decimalSepPos >>> 3) + 3); return number; } - private static void record(Result existingResult, int number) { + private static void record(Result existingResult, long number) { if (number < existingResult.min) { - existingResult.min = (short) number; + existingResult.min = number; } if (number > existingResult.max) { - existingResult.max = (short) number; + existingResult.max = number; } existingResult.sum += number; existingResult.count++; } private static int hashToIndex(long hash, Result[] results) { - int hashAsInt = (int) (hash ^ (hash >>> 28)); - int finalHash = (hashAsInt ^ (hashAsInt >>> 17)); - return (finalHash & (results.length - 1)); + long hashAsInt = hash ^ (hash >>> 37) ^ (hash >>> 17); + return (int) (hashAsInt & (results.length - 1)); } private static long mask(long word, long pos) { @@ -281,7 +380,7 @@ public class CalculateAverage_thomaswue { } // Special method to convert a number in the ascii number into an int without branches created by Quan Anh Mai. - private static int convertIntoNumber(int decimalSepPos, long numberWord) { + private static long convertIntoNumber(int decimalSepPos, long numberWord) { int shift = 28 - decimalSepPos; // signed is -1 if negative, 0 otherwise long signed = (~numberWord << 59) >> 63; @@ -292,8 +391,7 @@ public class CalculateAverage_thomaswue { // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) = // 0x000000UU00TTHH00 + 0x00UU00TTHH000000 * 10 + 0xUU00TTHH00000000 * 100 long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; - long value = (absValue ^ signed) - signed; - return (int) value; + return (absValue ^ signed) - signed; } private static long findDelimiter(long word) { @@ -302,6 +400,12 @@ public class CalculateAverage_thomaswue { return tmp; } + private static long findNewLine(long word) { + long input = word ^ 0x0A0A0A0A0A0A0A0AL; + long tmp = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; + return tmp; + } + private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner) { Result r = new Result(); results[hash] = r; @@ -324,27 +428,6 @@ public class CalculateAverage_thomaswue { return r; } - private static long[] getSegments(int numberOfChunks) throws IOException { - try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { - long fileSize = fileChannel.size(); - long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks; - long[] chunks = new long[numberOfChunks + 1]; - long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, java.lang.foreign.Arena.global()).address(); - chunks[0] = mappedAddress; - long endAddress = mappedAddress + fileSize; - Scanner s = new Scanner(mappedAddress, mappedAddress + fileSize); - for (int i = 1; i < numberOfChunks; ++i) { - long chunkAddress = mappedAddress + i * segmentSize; - // Align to first row start. - while (chunkAddress < endAddress && (s.getLongAt(chunkAddress++) & 0xFF) != '\n') - ; - chunks[i] = Math.min(chunkAddress, endAddress); - } - chunks[numberOfChunks] = endAddress; - return chunks; - } - } - private static class Scanner { private static final sun.misc.Unsafe UNSAFE = initUnsafe(); @@ -387,6 +470,10 @@ public class CalculateAverage_thomaswue { return UNSAFE.getLong(pos); } + long getLongAt(long pos, long[] array) { + return UNSAFE.getLong(array, pos + sun.misc.Unsafe.ARRAY_LONG_BASE_OFFSET); + } + void setPos(long l) { this.pos = l; }