From 271bdfb0329df636988455e450bb48a45f5b917f Mon Sep 17 00:00:00 2001 From: Van Phu DO Date: Fri, 26 Jan 2024 06:57:04 +0900 Subject: [PATCH] Simplify Node class with less field, improve hash mix speed (#584) * Simplify Node class with less field, improve hash mix speed * remove some ops, a bit faster * more inline, little bit faster but not sure --- prepare_abeobk.sh | 4 +- .../onebrc/CalculateAverage_abeobk.java | 126 +++++++++--------- 2 files changed, 63 insertions(+), 67 deletions(-) diff --git a/prepare_abeobk.sh b/prepare_abeobk.sh index d8ed86a..1b73743 100755 --- a/prepare_abeobk.sh +++ b/prepare_abeobk.sh @@ -20,6 +20,8 @@ sdk use java 21.0.2-graal 1>&2 # ./mvnw clean verify removes target/ and will re-trigger native image creation. if [ ! -f target/CalculateAverage_abeobk_image ]; then - NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk" + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -dsa -march=native -R:MaxHeapSize=128m -H:-GenLoopSafepoints -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk" native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk fi + + diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java index 293a88c..ed859f3 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java @@ -39,6 +39,7 @@ public class CalculateAverage_abeobk { private static final int BUCKET_SIZE = 1 << 16; private static final int BUCKET_MASK = BUCKET_SIZE - 1; private static final int MAX_STR_LEN = 100; + private static final int MAX_STATIONS = 10000; private static final Unsafe UNSAFE = initUnsafe(); private static final long[] HASH_MASKS = new long[]{ 0x0L, @@ -66,6 +67,33 @@ public class CalculateAverage_abeobk { } } + static class Stat { + Node node; + String key; + + public final String toString() { + return (node.min / 10.0) + "/" + + (Math.round(((double) node.sum / node.count)) / 10.0) + "/" + + (node.max / 10.0); + } + + Stat(Node n) { + node = n; + byte[] sbuf = new byte[MAX_STR_LEN]; + long word = UNSAFE.getLong(n.addr); + long semipos_code = getSemiPosCode(word); + int keylen = 0; + while (semipos_code == 0) { + keylen += 8; + word = UNSAFE.getLong(n.addr + keylen); + semipos_code = getSemiPosCode(word); + } + keylen += Long.numberOfTrailingZeros(semipos_code) >>> 3; + UNSAFE.copyMemory(null, n.addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen); + key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8); + } + } + static class Node { long addr; long word0; @@ -73,37 +101,23 @@ public class CalculateAverage_abeobk { long sum; int count; short min, max; - int keylen; - String key; - void calcKey() { - byte[] sbuf = new byte[MAX_STR_LEN]; - UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen); - key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8); - } - - public String toString() { - return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1); - } - - Node(long a, long t, short val, int kl) { + Node(long a, long t, short val) { addr = a; tail = t; - keylen = kl; sum = min = max = val; count = 1; } - Node(long a, long w0, long t, short val, int kl) { + Node(long a, long w0, long t, short val) { addr = a; word0 = w0; tail = t; - keylen = kl; sum = min = max = val; count = 1; } - void add(short val) { + final void add(short val) { sum += val; count++; if (val >= max) { @@ -115,7 +129,7 @@ public class CalculateAverage_abeobk { } } - void merge(Node other) { + final void merge(Node other) { sum += other.sum; count += other.count; if (other.max > max) { @@ -126,8 +140,8 @@ public class CalculateAverage_abeobk { } } - boolean contentEquals(long other_addr, long other_word0, long other_tail) { - if (tail != other_tail || word0 != other_word0) + final boolean contentEquals(long other_addr, long other_word0, long other_tail, int keylen) { + if (word0 != other_word0 || tail != other_tail) return false; // this is faster than comparision if key is short long xsum = 0; @@ -161,11 +175,8 @@ public class CalculateAverage_abeobk { // speed/collision balance static final int xxh32(long hash) { - final int p1 = 0x85EBCA77; // prime - int low = (int) hash; - int high = (int) (hash >>> 33); - int h = (low * p1) ^ high; - return h ^ (h >>> 17); + long h = hash * 37; + return (int) (h ^ (h >>> 29)); } // great idea from merykitty (Quan Anh Mai) @@ -185,11 +196,10 @@ public class CalculateAverage_abeobk { static final Node[] parse(int thread_id, long start, long end) { int cls = 0; long addr = start; - var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions + var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions // parse loop while (addr < end) { long row_addr = addr; - long hash = 0; long word0 = UNSAFE.getLong(addr); long semipos_code = getSemiPosCode(word0); @@ -202,14 +212,14 @@ public class CalculateAverage_abeobk { int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); addr += (dot_pos >>> 3) + 3; - long tail = (word0 & HASH_MASKS[semi_pos]); + long tail = word0 & HASH_MASKS[semi_pos]; int bucket = xxh32(tail) & BUCKET_MASK; short val = parseNum(num_word, dot_pos); while (true) { var node = map[bucket]; if (node == null) { - map[bucket] = new Node(row_addr, tail, val, semi_pos); + map[bucket] = new Node(row_addr, tail, val); break; } if (node.tail == tail) { @@ -223,28 +233,25 @@ public class CalculateAverage_abeobk { continue; } - hash ^= word0; addr += 8; long word = UNSAFE.getLong(addr); semipos_code = getSemiPosCode(word); // 43% chance if (semipos_code != 0) { int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos; - int keylen = (int) (addr - row_addr); - long num_word = UNSAFE.getLong(addr + 1); + addr += semi_pos + 1; + long num_word = UNSAFE.getLong(addr); int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); - addr += (dot_pos >>> 3) + 4; + addr += (dot_pos >>> 3) + 3; long tail = (word & HASH_MASKS[semi_pos]); - hash ^= tail; - int bucket = xxh32(hash) & BUCKET_MASK; + int bucket = xxh32(word0 ^ tail) & BUCKET_MASK; short val = parseNum(num_word, dot_pos); while (true) { var node = map[bucket]; if (node == null) { - map[bucket] = new Node(row_addr, word0, tail, val, keylen); + map[bucket] = new Node(row_addr, word0, tail, val); break; } if (node.word0 == word0 && node.tail == tail) { @@ -258,6 +265,9 @@ public class CalculateAverage_abeobk { continue; } + // why not going for more? tested, slower + + long hash = word0; while (semipos_code == 0) { hash ^= word; addr += 8; @@ -273,17 +283,16 @@ public class CalculateAverage_abeobk { addr += (dot_pos >>> 3) + 4; long tail = (word & HASH_MASKS[semi_pos]); - hash ^= tail; - int bucket = xxh32(hash) & BUCKET_MASK; + int bucket = xxh32(hash ^ tail) & BUCKET_MASK; short val = parseNum(num_word, dot_pos); while (true) { var node = map[bucket]; if (node == null) { - map[bucket] = new Node(row_addr, word0, tail, val, keylen); + map[bucket] = new Node(row_addr, word0, tail, val); break; } - if (node.contentEquals(row_addr, word0, tail)) { + if (node.contentEquals(row_addr, word0, tail, keylen)) { node.add(val); break; } @@ -292,6 +301,7 @@ public class CalculateAverage_abeobk { cls++; } } + if (SHOW_ANALYSIS) { debug("Thread %d collision = %d", thread_id, cls); } @@ -307,8 +317,6 @@ public class CalculateAverage_abeobk { workerCommand.add("--worker"); new ProcessBuilder() .command(workerCommand) - .inheritIO() - .redirectOutput(ProcessBuilder.Redirect.PIPE) .start() .getInputStream() .transferTo(System.out); @@ -333,43 +341,29 @@ public class CalculateAverage_abeobk { // processing var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt); - TreeMap ms = new TreeMap<>(); - int[] lenhist = new int[64]; // length histogram - - List> maps = IntStream.range(0, cpu_cnt) + List> maps = IntStream.range(0, cpu_cnt) .mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1])) .map(map -> { - List nodes = new ArrayList<>(); + List stats = new ArrayList<>(); for (var node : map) { if (node == null) continue; - node.calcKey(); - nodes.add(node); + stats.add(new Stat(node)); } - return nodes; + return stats; }) .parallel() .toList(); - for (var nodes : maps) { - for (var node : nodes) { - if (SHOW_ANALYSIS) { - int kl = node.keylen & (lenhist.length - 1); - lenhist[kl] += node.count; - } - var stat = ms.putIfAbsent(node.key, node); + TreeMap ms = new TreeMap<>(); + for (var stats : maps) { + for (var s : stats) { + var stat = ms.putIfAbsent(s.key, s); if (stat != null) - stat.merge(node); + stat.node.merge(s.node); } } - if (SHOW_ANALYSIS) { - debug("Total = " + Arrays.stream(lenhist).sum()); - debug("Length_histogram = " - + Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray())); - return; - } - // print result System.out.println(ms); System.out.close();