diff --git a/calculate_average_roman-r-m.sh b/calculate_average_roman-r-m.sh index 47626a1..fe468dc 100755 --- a/calculate_average_roman-r-m.sh +++ b/calculate_average_roman-r-m.sh @@ -16,4 +16,10 @@ # JAVA_OPTS="--enable-preview -XX:+UseTransparentHugePages" + +# epsilon GC needs enough memory or it makes things worse +# see https://stackoverflow.com/questions/58087596/why-are-repeated-memory-allocations-observed-to-be-slower-using-epsilon-vs-g1 +# 2GB seems to be the sweet spot +JAVA_OPTS="$JAVA_OPTS -XX:+UnlockExperimentalVMOptions -XX:-EnableJVMCI -XX:+UseEpsilonGC -Xmx2G -Xms2G -XX:+AlwaysPreTouch" + java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_roman_r_m diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java index 5c43824..a7df56e 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java @@ -33,37 +33,35 @@ public class CalculateAverage_roman_r_m { private static Unsafe UNSAFE; - // based on http://0x80.pl/notesen/2023-03-06-swar-find-any.html - static long hasZeroByte(long l) { - return ((l - 0x0101010101010101L) & ~(l) & 0x8080808080808080L); - } - - static long firstSetByteIndex(long l) { - return ((((l - 1) & 0x101010101010101L) * 0x101010101010101L) >> 56) - 1; - } - - static long broadcast(byte b) { + private static long broadcast(byte b) { return 0x101010101010101L * b; } - static long SEMICOLON_MASK = broadcast((byte) ';'); - static long LINE_END_MASK = broadcast((byte) '\n'); + private static final long SEMICOLON_MASK = broadcast((byte) ';'); + private static final long LINE_END_MASK = broadcast((byte) '\n'); + private static final long DOT_MASK = broadcast((byte) '.'); - static long find(long l, long mask) { - long xor = l ^ mask; - long match = hasZeroByte(xor); - return match != 0 ? firstSetByteIndex(match) : -1; + // from netty + + /** + * Applies a compiled pattern to given word. + * Returns a word where each byte that matches the pattern has the highest bit set. + */ + private static long applyPattern(final long word, final long pattern) { + long input = word ^ pattern; + long tmp = (input & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL; + return ~(tmp | input | 0x7F7F7F7F7F7F7F7FL); } static long nextNewline(long from, MemorySegment ms) { long start = from; long i; long next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, start); - while ((i = find(next, LINE_END_MASK)) < 0) { + while ((i = applyPattern(next, LINE_END_MASK)) == 0) { start += 8; next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, start); } - return start + i; + return start + Long.numberOfTrailingZeros(i) / 8; } static class Worker { @@ -84,55 +82,53 @@ public class CalculateAverage_roman_r_m { private void parseName(ByteString station) { long start = offset; - long pos = -1; - - while (end - offset > 8) { - long next = UNSAFE.getLong(offset); - pos = find(next, SEMICOLON_MASK); - if (pos >= 0) { - offset += pos; - break; - } - else { - offset += 8; - } - } - if (pos < 0) { - while (UNSAFE.getByte(offset++) != ';') { - } - offset--; + long pattern; + long next = UNSAFE.getLong(offset); + while ((pattern = applyPattern(next, SEMICOLON_MASK)) == 0) { + offset += 8; + next = UNSAFE.getLong(offset); } + int bytes = Long.numberOfTrailingZeros(pattern) / 8; + offset += bytes; int len = (int) (offset - start); station.offset = start; station.len = len; station.hash = 0; + station.tail = next & ((1L << (8 * bytes)) - 1); offset++; } - long parseNumberFast() { + int parseNumberFast() { long encodedVal = UNSAFE.getLong(offset); - var len = find(encodedVal, LINE_END_MASK); - offset += len + 1; + int neg = 1 - Integer.bitCount((int) (encodedVal & 0x10)); + encodedVal >>>= 8 * neg; + + var len = applyPattern(encodedVal, DOT_MASK); + len = Long.numberOfTrailingZeros(len) / 8; encodedVal ^= broadcast((byte) 0x30); - long c0 = len == 4 ? 100 : 10; - long c1 = 10 * (len - 3); - long c2 = 4 - len; - long c3 = len - 3; - long a = (encodedVal & 0xFF) * c0; - long b = ((encodedVal & 0xFF00) >>> 8) * c1; - long c = ((encodedVal & 0xFF0000L) >>> 16) * c2; - long d = ((encodedVal & 0xFF000000L) >>> 24) * c3; + int intPart = (int) (encodedVal & ((1 << (8 * len)) - 1)); + intPart <<= 8 * (2 - len); + intPart *= (100 * 256 + 10); + intPart = (intPart & 0x3FF80) >>> 8; - return a + b + c + d; + int frac = (int) ((encodedVal >>> (8 * (len + 1))) & 0xFF); + + offset += neg + len + 3; // 1 for . + 1 for fractional part + 1 for new line char + int sign = 1 - 2 * neg; + int val = intPart + frac; + return sign * val; } - long parseNumberSlow() { - long val = UNSAFE.getByte(offset++) - '0'; + int parseNumberSlow() { + int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10); + offset += neg; + + int val = UNSAFE.getByte(offset++) - '0'; byte b; while ((b = UNSAFE.getByte(offset++)) != '.') { val = val * 10 + (b - '0'); @@ -140,22 +136,17 @@ public class CalculateAverage_roman_r_m { b = UNSAFE.getByte(offset); val = val * 10 + (b - '0'); offset += 2; + val *= 1 - 2 * neg; return val; } - long parseNumber() { - long val; - int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10); - offset += neg; - - if (end - offset > 8) { - val = parseNumberFast(); + int parseNumber() { + if (end - offset >= 8) { + return parseNumberFast(); } else { - val = parseNumberSlow(); + return parseNumberSlow(); } - val *= 1 - 2 * neg; - return val; } public TreeMap run() { @@ -218,6 +209,7 @@ public class CalculateAverage_roman_r_m { private long offset; private int len = 0; private int hash = 0; + private long tail = 0L; ByteString(MemorySegment ms) { this.ms = ms; @@ -235,6 +227,7 @@ public class CalculateAverage_roman_r_m { copy.offset = this.offset; copy.len = this.len; copy.hash = this.hash; + copy.tail = this.tail; return copy; } @@ -259,19 +252,7 @@ public class CalculateAverage_roman_r_m { return false; } } - if (len >= 8) { - long l1 = UNSAFE.getLong(offset + len - 8); - long l2 = UNSAFE.getLong(that.offset + len - 8); - return l1 == l2; - } - for (; i < len; i++) { - byte i1 = UNSAFE.getByte(offset + i); - byte i2 = UNSAFE.getByte(that.offset + i); - if (i1 != i2) { - return false; - } - } - return true; + return this.tail == that.tail; } @Override