From 036f9a01b18ef2d35dac6bbceec46b1ccbfe4f2b Mon Sep 17 00:00:00 2001 From: Thomas Wuerthinger Date: Mon, 29 Jan 2024 22:19:23 +0100 Subject: [PATCH] Clean up, fine tuning, credit section for thomaswue (#646) * Some clean up, fine tuning, removing non-supported options, added credit section and additional comments. * Put license header year back to 2023 to pass checks. * Remove static linking (as it requires some more setup on the target machine). --- prepare_thomaswue.sh | 2 +- .../onebrc/CalculateAverage_thomaswue.java | 267 ++++++++---------- 2 files changed, 126 insertions(+), 143 deletions(-) diff --git a/prepare_thomaswue.sh b/prepare_thomaswue.sh index 10dc732..da0a591 100755 --- a/prepare_thomaswue.sh +++ b/prepare_thomaswue.sh @@ -20,7 +20,7 @@ sdk use java 21.0.2-graal 1>&2 # ./mvnw clean verify removes target/ and will re-trigger native image creation. if [ ! -f target/CalculateAverage_thomaswue_image ]; then - NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -H:-GenLoopSafepoints -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_thomaswue\$Scanner" + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -H:TuneInlinerExploration=1 -march=native --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_thomaswue\$Scanner" # Use -H:MethodFilter=CalculateAverage_thomaswue.* -H:Dump=:2 -H:PrintGraph=Network for IdealGraphVisualizer graph dumping. native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_thomaswue_image dev.morling.onebrc.CalculateAverage_thomaswue fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java index c02a881..9b21f91 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java @@ -16,122 +16,68 @@ package dev.morling.onebrc; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.nio.channels.FileChannel; -import java.nio.charset.StandardCharsets; -import java.nio.file.Path; -import java.nio.file.StandardOpenOption; import java.util.*; import java.util.concurrent.atomic.AtomicLong; -import java.util.stream.IntStream; /** - * Simple solution that memory maps the input file, then splits it into one segment per available core and uses - * sun.misc.Unsafe to directly access the mapped memory. Uses a long at a time when checking for collision. - *

- * Runs in 0.41s on my Intel i9-13900K - * Perf stats: - * 25,286,227,376 cpu_core/cycles/ - * 26,833,723,225 cpu_atom/cycles/ + * The solution starts a child worker process for the actual work such that clean up of the memory mapping can occur + * while the main process already returns with the result. The worker then memory maps the input file, creates a worker + * thread per available core, and then processes segments of size {@link #SEGMENT_SIZE} at a time. The segments are + * split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread. + * Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in + * the end. + * + * Runs in 0.40s on an Intel i9-13900K. + * + * Credit: + * Quan Anh Mai for branchless number parsing code + * Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea + * Artsiom Korzun for showing the benefits of work stealing at 2MB segments instead of equal split between workers */ public class CalculateAverage_thomaswue { private static final String FILE = "./measurements.txt"; private static final int MIN_TEMP = -999; private static final int MAX_TEMP = 999; - - // Holding the current result for a single city. - private static class Result { - long lastNameLong, secondLastNameLong; - long min, max; - long sum; - int count; - long[] name; - String nameAsString; - - private Result() { - this.min = MAX_TEMP; - this.max = MIN_TEMP; - } - - public String toString() { - return round(((double) min) / 10.0) + "/" + round((((double) sum) / 10.0) / count) + "/" + round(((double) max) / 10.0); - } - - private static double round(double value) { - return Math.round(value * 10.0) / 10.0; - } - - // Accumulate another result into this one. - private void add(Result other) { - if (other.min < min) { - min = other.min; - } - if (other.max > max) { - max = other.max; - } - sum += other.sum; - count += other.count; - } - - public String calcName() { - if (nameAsString == null) { - ByteBuffer bb = ByteBuffer.allocate(name.length * Long.BYTES).order(ByteOrder.nativeOrder()); - bb.asLongBuffer().put(name); - byte[] array = bb.array(); - int i = 0; - while (array[i++] != ';') - ; - nameAsString = new String(array, 0, i - 1, StandardCharsets.UTF_8); - } - return nameAsString; - } - } + private static final int MAX_NAME_LENGTH = 100; + private static final int MAX_CITIES = 10000; + private static final int SEGMENT_SIZE = 1 << 21; + private static final int HASH_TABLE_SIZE = 1 << 17; public static void main(String[] args) throws IOException, InterruptedException { + // Start worker subprocess if this process is not the worker. if (args.length == 0 || !("--worker".equals(args[0]))) { spawnWorker(); return; } - // Calculate input segments. + int numberOfWorkers = Runtime.getRuntime().availableProcessors(); - final AtomicLong cursor = new AtomicLong(); - final long fileEnd; - final long fileStart; - - try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { + try (var fileChannel = FileChannel.open(java.nio.file.Path.of(FILE), java.nio.file.StandardOpenOption.READ)) { long fileSize = fileChannel.size(); - fileStart = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, java.lang.foreign.Arena.global()).address(); - cursor.set(fileStart); - fileEnd = fileStart + fileSize; - } + final long fileStart = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, java.lang.foreign.Arena.global()).address(); + final long fileEnd = fileStart + fileSize; + final AtomicLong cursor = new AtomicLong(fileStart); - // Parallel processing of segments. - Thread[] threads = new Thread[numberOfWorkers]; - List[] allResults = new List[numberOfWorkers]; - for (int i = 0; i < threads.length; ++i) { - final int index = i; - threads[i] = new Thread(() -> { - Result[] resultArray = parseLoop(cursor, fileEnd, fileStart); - List results = new ArrayList<>(500); - for (Result r : resultArray) { - if (r != null) { - r.calcName(); - results.add(r); - } - } - allResults[index] = results; - }); - threads[i].start(); - } + // Parallel processing of segments. + Thread[] threads = new Thread[numberOfWorkers]; + List[] allResults = new List[numberOfWorkers]; + for (int i = 0; i < threads.length; ++i) { + final int index = i; + threads[i] = new Thread(() -> { + List results = new ArrayList<>(MAX_CITIES); + parseLoop(cursor, fileEnd, fileStart, results); + allResults[index] = results; + }); + threads[i].start(); + } + for (Thread thread : threads) { + thread.join(); + } - for (Thread thread : threads) { - thread.join(); + // Final output. + System.out.println(accumulateResults(allResults)); + System.out.close(); } - - // Final output. - System.out.println(accumulateResults(allResults)); - System.out.close(); } private static void spawnWorker() throws IOException { @@ -144,31 +90,30 @@ public class CalculateAverage_thomaswue { .start().getInputStream().transferTo(System.out); } - // Accumulate results sequentially for simplicity. private static TreeMap accumulateResults(List[] allResults) { TreeMap result = new TreeMap<>(); for (List resultArr : allResults) { for (Result r : resultArr) { - String name = r.calcName(); - Result current = result.putIfAbsent(name, r); + Result current = result.putIfAbsent(r.calcName(), r); if (current != null) { - current.add(r); + current.accumulate(r); } } } return result; } - private static Result findResult(long initialWord, long initialPos, Scanner scanner, Result[] results) { - + private static Result findResult(long initialWord, long initialPos, Scanner scanner, Result[] results, List collectedResults) { Result existingResult; long word = initialWord; long pos = initialPos; long hash; long nameAddress = scanner.pos(); - // Search for ';', one long at a time. + // Search for ';', one long at a time. There are two common cases that a specially treated: + // (b) the ';' is found in the first 16 bytes if (pos != 0) { + // Special case for when the ';' is found in the first 8 bytes. pos = Long.numberOfTrailingZeros(pos) >>> 3; scanner.add(pos); word = mask(word, pos); @@ -180,11 +125,10 @@ public class CalculateAverage_thomaswue { if (existingResult != null && existingResult.lastNameLong == word) { return existingResult; } - else { - scanner.setPos(nameAddress + pos); - } + scanner.setPos(nameAddress + pos); } else { + // Special case for when the ';' is found in bytes 9-16. scanner.add(8); hash = word; long prevWord = word; @@ -201,11 +145,10 @@ public class CalculateAverage_thomaswue { if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { return existingResult; } - else { - scanner.setPos(nameAddress + pos + 8); - } + scanner.setPos(nameAddress + pos + 8); } else { + // Slow-path for when the ';' could not be found in the first 16 bytes. scanner.add(8); hash ^= word; while (true) { @@ -234,20 +177,20 @@ public class CalculateAverage_thomaswue { outer: while (true) { existingResult = results[tableIndex]; if (existingResult == null) { - existingResult = newEntry(results, nameAddress, tableIndex, nameLength, scanner); + existingResult = newEntry(results, nameAddress, tableIndex, nameLength, scanner, collectedResults); } // Check for collision. int i = 0; - long[] name = existingResult.name; for (; i < nameLength + 1 - 8; i += 8) { - if (scanner.getLongAt(i, name) != scanner.getLongAt(nameAddress + i)) { + if (scanner.getLongAt(existingResult.nameAddress + i) != scanner.getLongAt(nameAddress + i)) { + // Collision error, try next. tableIndex = (tableIndex + 31) & (results.length - 1); continue outer; } } int remainingShift = (64 - (nameLength + 1 - i) << 3); - if (((existingResult.lastNameLong ^ (scanner.getLongAt(nameAddress + i) << remainingShift)) == 0)) { + if (existingResult.lastNameLong == (scanner.getLongAt(nameAddress + i) << remainingShift)) { break; } else { @@ -258,7 +201,7 @@ public class CalculateAverage_thomaswue { return existingResult; } - private static long nextNL(long prev) { + private static long nextNewLine(long prev) { while (true) { long currentWord = Scanner.UNSAFE.getLong(prev); long pos = findNewLine(currentWord); @@ -273,11 +216,9 @@ public class CalculateAverage_thomaswue { return prev; } - private static final int SEGMENT_SIZE = 1024 * 1024 * 2; - // Main parse loop. - private static Result[] parseLoop(AtomicLong counter, long fileEnd, long fileStart) { - Result[] results = new Result[1 << 17]; + private static Result[] parseLoop(AtomicLong counter, long fileEnd, long fileStart, List collectedResults) { + Result[] results = new Result[HASH_TABLE_SIZE]; while (true) { long current = counter.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE; @@ -286,18 +227,18 @@ public class CalculateAverage_thomaswue { return results; } - long segmentEnd = nextNL(Math.min(fileEnd - 1, current + SEGMENT_SIZE)); + long segmentEnd = nextNewLine(Math.min(fileEnd - 1, current + SEGMENT_SIZE)); long segmentStart; if (current == fileStart) { segmentStart = current; } else { - segmentStart = nextNL(current) + 1; + segmentStart = nextNewLine(current) + 1; } long dist = (segmentEnd - segmentStart) / 3; - long midPoint1 = nextNL(segmentStart + dist); - long midPoint2 = nextNL(segmentStart + dist + dist); + long midPoint1 = nextNewLine(segmentStart + dist); + long midPoint2 = nextNewLine(segmentStart + dist + dist); Scanner scanner1 = new Scanner(segmentStart, midPoint1); Scanner scanner2 = new Scanner(midPoint1 + 1, midPoint2); @@ -319,9 +260,9 @@ public class CalculateAverage_thomaswue { long pos1 = findDelimiter(word1); long pos2 = findDelimiter(word2); long pos3 = findDelimiter(word3); - Result existingResult1 = findResult(word1, pos1, scanner1, results); - Result existingResult2 = findResult(word2, pos2, scanner2, results); - Result existingResult3 = findResult(word3, pos3, scanner3, results); + Result existingResult1 = findResult(word1, pos1, scanner1, results, collectedResults); + Result existingResult2 = findResult(word2, pos2, scanner2, results, collectedResults); + Result existingResult3 = findResult(word3, pos3, scanner3, results, collectedResults); long number1 = scanNumber(scanner1); long number2 = scanNumber(scanner2); long number3 = scanNumber(scanner3); @@ -333,19 +274,19 @@ public class CalculateAverage_thomaswue { while (scanner1.hasNext()) { long word = scanner1.getLong(); long pos = findDelimiter(word); - record(findResult(word, pos, scanner1, results), scanNumber(scanner1)); + record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1)); } while (scanner2.hasNext()) { long word = scanner2.getLong(); long pos = findDelimiter(word); - record(findResult(word, pos, scanner2, results), scanNumber(scanner2)); + record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2)); } while (scanner3.hasNext()) { long word = scanner3.getLong(); long pos = findDelimiter(word); - record(findResult(word, pos, scanner3, results), scanNumber(scanner3)); + record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3)); } } } @@ -361,10 +302,10 @@ public class CalculateAverage_thomaswue { private static void record(Result existingResult, long number) { if (number < existingResult.min) { - existingResult.min = number; + existingResult.min = (short) number; } if (number > existingResult.max) { - existingResult.max = number; + existingResult.max = (short) number; } existingResult.sum += number; existingResult.count++; @@ -406,31 +347,71 @@ public class CalculateAverage_thomaswue { return tmp; } - private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner) { + private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List collectedResults) { Result r = new Result(); results[hash] = r; - long[] name = new long[(nameLength / Long.BYTES) + 1]; - int pos = 0; int i = 0; for (; i < nameLength + 1 - Long.BYTES; i += Long.BYTES) { - name[pos++] = scanner.getLongAt(nameAddress + i); } - - if (pos > 0) { - r.secondLastNameLong = name[pos - 1]; + if (nameLength + 1 > 8) { + r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8); } - int remainingShift = (64 - (nameLength + 1 - i) << 3); long lastWord = (scanner.getLongAt(nameAddress + i) << remainingShift); r.lastNameLong = lastWord; - name[pos] = lastWord >> remainingShift; - r.name = name; + r.nameAddress = nameAddress; + collectedResults.add(r); return r; } - private static class Scanner { + private static class Result { + long lastNameLong, secondLastNameLong; + short min, max; + int count; + long sum; + long nameAddress; + private Result() { + this.min = MAX_TEMP; + this.max = MIN_TEMP; + } + + public String toString() { + return round(((double) min) / 10.0) + "/" + round((((double) sum) / 10.0) / count) + "/" + round(((double) max) / 10.0); + } + + private static double round(double value) { + return Math.round(value * 10.0) / 10.0; + } + + private void accumulate(Result other) { + if (other.min < min) { + min = other.min; + } + if (other.max > max) { + max = other.max; + } + sum += other.sum; + count += other.count; + } + + public String calcName() { + Scanner scanner = new Scanner(nameAddress, nameAddress + MAX_NAME_LENGTH + 1); + int nameLength = 0; + while (scanner.getByteAt(nameAddress + nameLength) != ';') { + nameLength++; + } + byte[] array = new byte[nameLength]; + for (int i = 0; i < nameLength; ++i) { + array[i] = scanner.getByteAt(nameAddress + i); + } + return new String(array, java.nio.charset.StandardCharsets.UTF_8); + } + } + + private static class Scanner { private static final sun.misc.Unsafe UNSAFE = initUnsafe(); + private long pos, end; private static sun.misc.Unsafe initUnsafe() { try { @@ -443,8 +424,6 @@ public class CalculateAverage_thomaswue { } } - long pos, end; - public Scanner(long start, long end) { this.pos = start; this.end = end; @@ -470,6 +449,10 @@ public class CalculateAverage_thomaswue { return UNSAFE.getLong(pos); } + byte getByteAt(long pos) { + return UNSAFE.getByte(pos); + } + long getLongAt(long pos, long[] array) { return UNSAFE.getLong(array, pos + sun.misc.Unsafe.ARRAY_LONG_BASE_OFFSET); }