From 6c0949969a5df71df85397ce3338f0154e6727ac Mon Sep 17 00:00:00 2001 From: Roman Musin <995612+roman-r-m@users.noreply.github.com> Date: Tue, 23 Jan 2024 19:19:07 +0000 Subject: [PATCH] Native image + a few smaller optimisations (#564) * Inline parsing name and station to avoid constantly updating the offset field (-100ms) * Remove Worker class, inline the logic into lambda * Accumulate results in an int matrix instead of using result row (-50ms) * Use native image --- calculate_average_roman-r-m.sh | 13 +- prepare_roman-r-m.sh | 9 + .../onebrc/CalculateAverage_roman_r_m.java | 247 ++++++++---------- 3 files changed, 135 insertions(+), 134 deletions(-) diff --git a/calculate_average_roman-r-m.sh b/calculate_average_roman-r-m.sh index b5d0b3d..acf9864 100755 --- a/calculate_average_roman-r-m.sh +++ b/calculate_average_roman-r-m.sh @@ -21,4 +21,15 @@ JAVA_OPTS="--enable-preview -XX:+UseTransparentHugePages" # see https://stackoverflow.com/questions/58087596/why-are-repeated-memory-allocations-observed-to-be-slower-using-epsilon-vs-g1 JAVA_OPTS="$JAVA_OPTS -XX:+UnlockExperimentalVMOptions -XX:-EnableJVMCI -XX:+UseEpsilonGC -Xmx1G -Xms1G -XX:+AlwaysPreTouch" -java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_roman_r_m +if [ -f target/CalculateAverage_roman_r_m_image ]; then + echo "Picking up existing native image 'target/CalculateAverage_roman_r_m_image', delete the file to select JVM mode." 1>&2 + target/CalculateAverage_roman_r_m_image +else + JAVA_OPTS="--enable-preview -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA" + if [[ ! "$(uname -s)" = "Darwin" ]]; then + # On OS/X, my machine, this errors: + JAVA_OPTS="$JAVA_OPTS -XX:+UseTransparentHugePages" + fi + echo "Choosing to run the app in JVM mode as no native image was found, use additional_build_step_roman_r_m.sh to generate." 1>&2 + java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_roman_r_m +fi diff --git a/prepare_roman-r-m.sh b/prepare_roman-r-m.sh index f83a3ff..a0593b2 100755 --- a/prepare_roman-r-m.sh +++ b/prepare_roman-r-m.sh @@ -17,3 +17,12 @@ source "$HOME/.sdkman/bin/sdkman-init.sh" sdk use java 21.0.1-graal 1>&2 + +# ./mvnw clean verify removes target/ and will re-trigger native image creation. +if [ ! -f target/CalculateAverage_roman_r_m_image ]; then + + JAVA_OPTS="--enable-preview -dsa" + NATIVE_IMAGE_OPTS="--initialize-at-build-time=dev.morling.onebrc.CalculateAverage_roman_r_m --gc=epsilon -Ob -O3 -march=native --strict-image-heap $JAVA_OPTS" + + native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_roman_r_m_image dev.morling.onebrc.CalculateAverage_roman_r_m +fi \ No newline at end of file diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java index 1a43ae5..896616d 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java @@ -64,119 +64,6 @@ public class CalculateAverage_roman_r_m { return start + Long.numberOfTrailingZeros(i) / 8; } - static class Worker { - private final MemorySegment ms; - private final long end; - private long offset; - - public Worker(FileChannel channel, long start, long end) { - try { - this.ms = channel.map(FileChannel.MapMode.READ_ONLY, start, end - start, Arena.ofConfined()); - this.offset = ms.address(); - this.end = ms.address() + end - start; - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - - private void parseName(ByteString station) { - long start = offset; - long next = UNSAFE.getLong(offset); - long pattern = applyPattern(next, SEMICOLON_MASK); - int bytes; - if (pattern != 0) { - bytes = Long.numberOfTrailingZeros(pattern) / 8; - offset += bytes; - long h = Long.reverseBytes(next) >>> (8 * (8 - bytes)); - station.hash = (int) (h ^ (h >>> 32)); - } - else { - long h = next; - station.hash = (int) (h ^ (h >>> 32)); - while (pattern == 0) { - offset += 8; - next = UNSAFE.getLong(offset); - pattern = applyPattern(next, SEMICOLON_MASK); - } - bytes = Long.numberOfTrailingZeros(pattern) / 8; - offset += bytes; - } - - int len = (int) (offset - start); - station.offset = start; - station.len = len; - station.tail = next & ((1L << (8 * bytes)) - 1); - - offset++; - } - - int parseNumberFast() { - long encodedVal = UNSAFE.getLong(offset); - - int neg = 1 - Integer.bitCount((int) (encodedVal & 0x10)); - encodedVal >>>= 8 * neg; - - var len = applyPattern(encodedVal, DOT_MASK); - len = Long.numberOfTrailingZeros(len) / 8; - - encodedVal ^= broadcast((byte) 0x30); - - int intPart = (int) (encodedVal & ((1 << (8 * len)) - 1)); - intPart <<= 8 * (2 - len); - intPart *= (100 * 256 + 10); - intPart = (intPart & 0x3FF80) >>> 8; - - int frac = (int) ((encodedVal >>> (8 * (len + 1))) & 0xFF); - - offset += neg + len + 3; // 1 for . + 1 for fractional part + 1 for new line char - int sign = 1 - 2 * neg; - int val = intPart + frac; - return sign * val; - } - - int parseNumberSlow() { - int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10); - offset += neg; - - int val = UNSAFE.getByte(offset++) - '0'; - byte b; - while ((b = UNSAFE.getByte(offset++)) != '.') { - val = val * 10 + (b - '0'); - } - b = UNSAFE.getByte(offset); - val = val * 10 + (b - '0'); - offset += 2; - val *= 1 - 2 * neg; - return val; - } - - int parseNumber() { - if (end - offset >= 8) { - return parseNumberFast(); - } - else { - return parseNumberSlow(); - } - } - - public TreeMap run() { - var resultStore = new ResultStore(); - var station = new ByteString(ms); - - while (offset < end) { - parseName(station); - long val = parseNumber(); - var a = resultStore.get(station); - a.min = Math.min(a.min, val); - a.max = Math.max(a.max, val); - a.sum += val; - a.count++; - } - return resultStore.toMap(); - } - } - public static void main(String[] args) throws Exception { Field f = Unsafe.class.getDeclaredField("theUnsafe"); f.setAccessible(true); @@ -200,12 +87,97 @@ public class CalculateAverage_roman_r_m { var result = IntStream.range(0, numThreads) .parallel() .mapToObj(i -> { - long start = i == 0 ? 0 : bounds[i - 1] + 1; - long end = bounds[i]; - Worker worker = new Worker(channel, start, end); - var res = worker.run(); - worker.ms.unload(); - return res; + try { + long segmentStart = i == 0 ? 0 : bounds[i - 1] + 1; + long segmentEnd = bounds[i]; + var segment = channel.map(FileChannel.MapMode.READ_ONLY, segmentStart, segmentEnd - segmentStart, Arena.ofConfined()); + + var resultStore = new ResultStore(); + var station = new ByteString(segment); + long offset = segment.address(); + long end = offset + segment.byteSize(); + while (offset < end) { + // parsing station name + long start = offset; + long next = UNSAFE.getLong(offset); + long pattern = applyPattern(next, SEMICOLON_MASK); + int bytes; + if (pattern != 0) { + bytes = Long.numberOfTrailingZeros(pattern) / 8; + offset += bytes; + long h = Long.reverseBytes(next) >>> (8 * (8 - bytes)); + station.hash = (int) (h ^ (h >>> 32)); + } + else { + long h = next; + station.hash = (int) (h ^ (h >>> 32)); + while (pattern == 0) { + offset += 8; + next = UNSAFE.getLong(offset); + pattern = applyPattern(next, SEMICOLON_MASK); + } + bytes = Long.numberOfTrailingZeros(pattern) / 8; + offset += bytes; + } + + int len = (int) (offset - start); + station.offset = start; + station.len = len; + station.tail = next & ((1L << (8 * bytes)) - 1); + + offset++; + + // parsing temperature + // TODO next may contain temperature as well, maybe try using it if we know the full number is there + // 8 - bytes >= 5 -> bytes <= 3 + long val; + if (end - offset >= 8) { + long encodedVal = UNSAFE.getLong(offset); + + int neg = 1 - Integer.bitCount((int) (encodedVal & 0x10)); + encodedVal >>>= 8 * neg; + + long numLen = applyPattern(encodedVal, DOT_MASK); + numLen = Long.numberOfTrailingZeros(numLen) / 8; + + encodedVal ^= broadcast((byte) 0x30); + + int intPart = (int) (encodedVal & ((1 << (8 * numLen)) - 1)); + intPart <<= 8 * (2 - numLen); + intPart *= (100 * 256 + 10); + intPart = (intPart & 0x3FF80) >>> 8; + + int frac = (int) ((encodedVal >>> (8 * (numLen + 1))) & 0xFF); + + offset += neg + numLen + 3; // 1 for . + 1 for fractional part + 1 for new line char + int sign = 1 - 2 * neg; + val = sign * (intPart + frac); + } + else { + int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10); + offset += neg; + + val = UNSAFE.getByte(offset++) - '0'; + byte b; + while ((b = UNSAFE.getByte(offset++)) != '.') { + val = val * 10 + (b - '0'); + } + b = UNSAFE.getByte(offset); + val = val * 10 + (b - '0'); + offset += 2; + val *= 1 - (2L * neg); + } + + resultStore.update(station, (int) val); + } + + segment.unload(); + + return resultStore.toMap(); + } + catch (Exception e) { + throw new RuntimeException(e); + } }).reduce((m1, m2) -> { m2.forEach((k, v) -> m1.merge(k, v, ResultRow::merge)); return m1; @@ -275,10 +247,17 @@ public class CalculateAverage_roman_r_m { } private static final class ResultRow { - long min = 1000; - long sum = 0; - long max = -1000; - int count = 0; + long min; + long sum; + long max; + int count; + + public ResultRow(int[] values) { + min = values[0]; + max = values[1]; + sum = values[2]; + count = values[3]; + } public String toString() { return round(min / 10.0) + "/" + round(sum / 10.0 / count) + "/" + round(max / 10.0); @@ -300,9 +279,9 @@ public class CalculateAverage_roman_r_m { static class ResultStore { private static final int SIZE = 16384; private final ByteString[] keys = new ByteString[SIZE]; - private final ResultRow[] values = new ResultRow[SIZE]; + private final int[][] values = new int[SIZE][]; - ResultRow get(ByteString s) { + void update(ByteString s, int value) { int h = s.hashCode(); int idx = (SIZE - 1) & h; @@ -311,18 +290,20 @@ public class CalculateAverage_roman_r_m { i++; idx = (idx + i * i) % SIZE; } - ResultRow result; if (keys[idx] == null) { keys[idx] = s.copy(); - result = new ResultRow(); - values[idx] = result; + values[idx] = new int[4]; + values[idx][0] = value; + values[idx][1] = value; + values[idx][2] = value; + values[idx][3] = 1; } else { - result = values[idx]; - // TODO see it it makes any difference - // keys[idx].offset = s.offset; + values[idx][0] = Math.min(values[idx][0], value); + values[idx][1] = Math.max(values[idx][1], value); + values[idx][2] += value; + values[idx][3] += 1; } - return result; } TreeMap toMap() { @@ -330,7 +311,7 @@ public class CalculateAverage_roman_r_m { var result = new TreeMap(); for (int i = 0; i < SIZE; i++) { if (keys[i] != null) { - result.put(keys[i].asString(buf), values[i]); + result.put(keys[i].asString(buf), new ResultRow(values[i])); } } return result;