From d0a28599c293d3afe3291fc3cf169a7b25ae9ae6 Mon Sep 17 00:00:00 2001 From: Thomas Wuerthinger Date: Sun, 21 Jan 2024 20:13:48 +0100 Subject: [PATCH] Tuning and subprocess spawn for thomaswue (#533) * Some clean up, small-scale tuning, and reduce complexity when handling longer names. * Do actual work in worker subprocess. Main process returns immediately and OS clean up of the mmap continues in the subprocess. * Update minor Graal version after CPU release. * Turn GC back to epsilon GC (although it does not seem to make a difference). * Minor tuning for another +1%. --- prepare_thomaswue.sh | 4 +- .../onebrc/CalculateAverage_thomaswue.java | 167 ++++++++++-------- 2 files changed, 99 insertions(+), 72 deletions(-) diff --git a/prepare_thomaswue.sh b/prepare_thomaswue.sh index 1c6be64..32616a9 100755 --- a/prepare_thomaswue.sh +++ b/prepare_thomaswue.sh @@ -16,11 +16,11 @@ # source "$HOME/.sdkman/bin/sdkman-init.sh" -sdk use java 21.0.1-graal 1>&2 +sdk use java 21.0.2-graal 1>&2 # ./mvnw clean verify removes target/ and will re-trigger native image creation. if [ ! -f target/CalculateAverage_thomaswue_image ]; then - NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview" + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_thomaswue\$Scanner" # Use -H:MethodFilter=CalculateAverage_thomaswue.* -H:Dump=:2 -H:PrintGraph=Network for IdealGraphVisualizer graph dumping. native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_thomaswue_image dev.morling.onebrc.CalculateAverage_thomaswue fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java index 041c17c..406c85d 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java @@ -15,13 +15,10 @@ */ package dev.morling.onebrc; -import sun.misc.Unsafe; - import java.io.IOException; -import java.lang.foreign.Arena; -import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.channels.FileChannel; -import java.nio.channels.FileChannel.MapMode; import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; @@ -39,18 +36,20 @@ import java.util.stream.IntStream; */ public class CalculateAverage_thomaswue { private static final String FILE = "./measurements.txt"; + private static final int MIN_TEMP = -999; + private static final int MAX_TEMP = 999; // Holding the current result for a single city. private static class Result { - long lastNameLong, secondLastNameLong, nameAddress; - int nameLength, remainingShift; - int min, max, count; + long lastNameLong, secondLastNameLong; + long[] name; + int count; + short min, max; long sum; - private Result(long nameAddress) { - this.nameAddress = nameAddress; - this.min = Integer.MAX_VALUE; - this.max = Integer.MIN_VALUE; + private Result() { + this.min = MAX_TEMP; + this.max = MIN_TEMP; } public String toString() { @@ -63,18 +62,32 @@ public class CalculateAverage_thomaswue { // Accumulate another result into this one. private void add(Result other) { - min = Math.min(min, other.min); - max = Math.max(max, other.max); + if (other.min < min) { + min = other.min; + } + if (other.max > max) { + max = other.max; + } sum += other.sum; count += other.count; } public String calcName() { - return new Scanner(nameAddress, nameAddress + nameLength).getString(nameLength); + ByteBuffer bb = ByteBuffer.allocate(name.length * Long.BYTES).order(ByteOrder.nativeOrder()); + bb.asLongBuffer().put(name); + byte[] array = bb.array(); + int i = 0; + while (array[i++] != ';') + ; + return new String(array, 0, i - 1, StandardCharsets.UTF_8); } } public static void main(String[] args) throws IOException { + if (args.length == 0 || !("--worker".equals(args[0]))) { + spawnWorker(); + return; + } // Calculate input segments. int numberOfChunks = Runtime.getRuntime().availableProcessors(); long[] chunks = getSegments(numberOfChunks); @@ -93,6 +106,22 @@ public class CalculateAverage_thomaswue { // Final output. System.out.println(accumulateResults(allResults)); + System.out.close(); + } + + private static void spawnWorker() throws IOException { + ProcessHandle.Info info = ProcessHandle.current().info(); + ArrayList workerCommand = new ArrayList<>(); + info.command().ifPresent(workerCommand::add); + info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args))); + workerCommand.add("--worker"); + new ProcessBuilder() + .command(workerCommand) + .inheritIO() + .redirectOutput(ProcessBuilder.Redirect.PIPE) + .start() + .getInputStream() + .transferTo(System.out); } // Accumulate results sequentially for simplicity. @@ -115,20 +144,21 @@ public class CalculateAverage_thomaswue { Result[] results = new Result[1 << 17]; Scanner scanner = new Scanner(chunkStart, chunkEnd); long word = scanner.getLong(); - int pos = findDelimiter(word); + long pos = findDelimiter(word); while (scanner.hasNext()) { long nameAddress = scanner.pos(); long hash = 0; // Search for ';', one long at a time. - if (pos != 8) { + if (pos != 0) { + pos = Long.numberOfTrailingZeros(pos) >>> 3; scanner.add(pos); word = mask(word, pos); hash = word; int number = scanNumber(scanner); long nextWord = scanner.getLong(); - int nextPos = findDelimiter(nextWord); + long nextPos = findDelimiter(nextWord); Result existingResult = results[hashToIndex(hash, results)]; if (existingResult != null && existingResult.lastNameLong == word) { @@ -142,11 +172,12 @@ public class CalculateAverage_thomaswue { } else { scanner.add(8); - hash ^= word; + hash = word; long prevWord = word; word = scanner.getLong(); pos = findDelimiter(word); - if (pos != 8) { + if (pos != 0) { + pos = Long.numberOfTrailingZeros(pos) >>> 3; scanner.add(pos); word = mask(word, pos); hash ^= word; @@ -166,7 +197,8 @@ public class CalculateAverage_thomaswue { while (true) { word = scanner.getLong(); pos = findDelimiter(word); - if (pos != 8) { + if (pos != 0) { + pos = Long.numberOfTrailingZeros(pos) >>> 3; scanner.add(pos); word = mask(word, pos); hash ^= word; @@ -182,12 +214,7 @@ public class CalculateAverage_thomaswue { // Save length of name for later. int nameLength = (int) (scanner.pos() - nameAddress); - scanner.add(1); - - long numberWord = scanner.getLong(); - int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); - int number = convertIntoNumber(decimalSepPos, numberWord); - scanner.add((decimalSepPos >>> 3) + 3); + int number = scanNumber(scanner); // Final calculation for index into hash table. int tableIndex = hashToIndex(hash, results); @@ -198,13 +225,16 @@ public class CalculateAverage_thomaswue { } // Check for collision. int i = 0; + int namePos = 0; for (; i < nameLength + 1 - 8; i += 8) { - if (scanner.getLongAt(existingResult.nameAddress + i) != scanner.getLongAt(nameAddress + i)) { + if (namePos >= existingResult.name.length || existingResult.name[namePos++] != scanner.getLongAt(nameAddress + i)) { tableIndex = (tableIndex + 31) & (results.length - 1); continue outer; } } - if (((existingResult.lastNameLong ^ scanner.getLongAt(nameAddress + i)) << existingResult.remainingShift) == 0) { + + int remainingShift = (64 - (nameLength + 1 - i) << 3); + if (((existingResult.lastNameLong ^ (scanner.getLongAt(nameAddress + i) << remainingShift)) == 0)) { record(existingResult, number); break; } @@ -230,63 +260,67 @@ public class CalculateAverage_thomaswue { } private static void record(Result existingResult, int number) { - existingResult.min = Math.min(existingResult.min, number); - existingResult.max = Math.max(existingResult.max, number); + if (number < existingResult.min) { + existingResult.min = (short) number; + } + if (number > existingResult.max) { + existingResult.max = (short) number; + } existingResult.sum += number; existingResult.count++; } private static int hashToIndex(long hash, Result[] results) { int hashAsInt = (int) (hash ^ (hash >>> 28)); - int finalHash = (hashAsInt ^ (hashAsInt >>> 15)); + int finalHash = (hashAsInt ^ (hashAsInt >>> 17)); return (finalHash & (results.length - 1)); } - private static long mask(long word, int pos) { - return word & (-1L >>> ((8 - pos - 1) << 3)); + private static long mask(long word, long pos) { + return (word << ((7 - pos) << 3)); } - // Special method to convert a number in the specific format into an int value without branches created by - // Quan Anh Mai. + // Special method to convert a number in the ascii number into an int without branches created by Quan Anh Mai. private static int convertIntoNumber(int decimalSepPos, long numberWord) { int shift = 28 - decimalSepPos; // signed is -1 if negative, 0 otherwise long signed = (~numberWord << 59) >> 63; long designMask = ~(signed & 0xFF); - // Align the number to a specific position and transform the ascii code - // to actual digit value in each byte + // Align the number to a specific position and transform the ascii to digit value long digits = ((numberWord & designMask) << shift) & 0x0F000F0F00L; - // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit) // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) = - // 0x000000UU00TTHH00 + - // 0x00UU00TTHH000000 * 10 + - // 0xUU00TTHH00000000 * 100 - // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400 - // This results in our value lies in the bit 32 to 41 of this product - // That was close :) + // 0x000000UU00TTHH00 + 0x00UU00TTHH000000 * 10 + 0xUU00TTHH00000000 * 100 long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; long value = (absValue ^ signed) - signed; return (int) value; } - private static int findDelimiter(long word) { + private static long findDelimiter(long word) { long input = word ^ 0x3B3B3B3B3B3B3B3BL; long tmp = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; - return Long.numberOfTrailingZeros(tmp) >>> 3; + return tmp; } private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner) { - Result r = new Result(nameAddress); + Result r = new Result(); results[hash] = r; - + long[] name = new long[(nameLength / Long.BYTES) + 1]; + int pos = 0; int i = 0; - for (; i < nameLength + 1 - 8; i += 8) { - r.secondLastNameLong = (scanner.getLongAt(nameAddress + i)); + for (; i < nameLength + 1 - Long.BYTES; i += Long.BYTES) { + name[pos++] = scanner.getLongAt(nameAddress + i); } - r.remainingShift = (64 - (nameLength + 1 - i) << 3); - r.lastNameLong = (scanner.getLongAt(nameAddress + i) & (-1L >>> r.remainingShift)); - r.nameLength = nameLength; + + if (pos > 0) { + r.secondLastNameLong = name[pos - 1]; + } + + int remainingShift = (64 - (nameLength + 1 - i) << 3); + long lastWord = (scanner.getLongAt(nameAddress + i) << remainingShift); + r.lastNameLong = lastWord; + name[pos] = lastWord >> remainingShift; + r.name = name; return r; } @@ -295,16 +329,15 @@ public class CalculateAverage_thomaswue { long fileSize = fileChannel.size(); long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks; long[] chunks = new long[numberOfChunks + 1]; - long mappedAddress = fileChannel.map(MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); + long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, java.lang.foreign.Arena.global()).address(); chunks[0] = mappedAddress; long endAddress = mappedAddress + fileSize; Scanner s = new Scanner(mappedAddress, mappedAddress + fileSize); for (int i = 1; i < numberOfChunks; ++i) { long chunkAddress = mappedAddress + i * segmentSize; // Align to first row start. - while (chunkAddress < endAddress && (s.getLongAt(chunkAddress++) & 0xFF) != '\n') { - // nop - } + while (chunkAddress < endAddress && (s.getLongAt(chunkAddress++) & 0xFF) != '\n') + ; chunks[i] = Math.min(chunkAddress, endAddress); } chunks[numberOfChunks] = endAddress; @@ -314,13 +347,13 @@ public class CalculateAverage_thomaswue { private static class Scanner { - private static final Unsafe UNSAFE = initUnsafe(); + private static final sun.misc.Unsafe UNSAFE = initUnsafe(); - private static Unsafe initUnsafe() { + private static sun.misc.Unsafe initUnsafe() { try { - Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + java.lang.reflect.Field theUnsafe = sun.misc.Unsafe.class.getDeclaredField("theUnsafe"); theUnsafe.setAccessible(true); - return (Unsafe) theUnsafe.get(Unsafe.class); + return (sun.misc.Unsafe) theUnsafe.get(sun.misc.Unsafe.class); } catch (NoSuchFieldException | IllegalAccessException e) { throw new RuntimeException(e); @@ -342,7 +375,7 @@ public class CalculateAverage_thomaswue { return pos; } - void add(int delta) { + void add(long delta) { pos += delta; } @@ -354,13 +387,7 @@ public class CalculateAverage_thomaswue { return UNSAFE.getLong(pos); } - public String getString(int nameLength) { - byte[] bytes = new byte[nameLength]; - UNSAFE.copyMemory(null, pos, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength); - return new String(bytes, StandardCharsets.UTF_8); - } - - public void setPos(long l) { + void setPos(long l) { this.pos = l; } }