From 98a8279669d0483b59cc40b8809e654758b5ad54 Mon Sep 17 00:00:00 2001 From: Van Phu DO Date: Wed, 24 Jan 2024 00:41:25 +0900 Subject: [PATCH] use thomaswue trick, use parallelism, slightly faster (#560) --- prepare_abeobk.sh | 4 +- .../onebrc/CalculateAverage_abeobk.java | 155 +++++++++++------- 2 files changed, 96 insertions(+), 63 deletions(-) diff --git a/prepare_abeobk.sh b/prepare_abeobk.sh index fac7b87..d8ed86a 100755 --- a/prepare_abeobk.sh +++ b/prepare_abeobk.sh @@ -16,10 +16,10 @@ # source "$HOME/.sdkman/bin/sdkman-init.sh" -sdk use java 21.0.1-graal 1>&2 +sdk use java 21.0.2-graal 1>&2 # ./mvnw clean verify removes target/ and will re-trigger native image creation. if [ ! -f target/CalculateAverage_abeobk_image ]; then - NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk" + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk" native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java index 48d9da6..293a88c 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java @@ -24,8 +24,12 @@ import java.nio.channels.FileChannel.MapMode; import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.TreeMap; +import java.util.stream.IntStream; + import sun.misc.Unsafe; public class CalculateAverage_abeobk { @@ -66,22 +70,23 @@ public class CalculateAverage_abeobk { long addr; long word0; long tail; - int keylen; - int min, max; - int count; long sum; + int count; + short min, max; + int keylen; + String key; - String key() { + void calcKey() { byte[] sbuf = new byte[MAX_STR_LEN]; UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen); - return new String(sbuf, 0, keylen, StandardCharsets.UTF_8); + key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8); } public String toString() { return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1); } - Node(long a, long t, int val, int kl) { + Node(long a, long t, short val, int kl) { addr = a; tail = t; keylen = kl; @@ -89,12 +94,16 @@ public class CalculateAverage_abeobk { count = 1; } - Node(long a, long t, int val, int kl, long w0) { - this(a, t, val, kl); + Node(long a, long w0, long t, short val, int kl) { + addr = a; word0 = w0; + tail = t; + keylen = kl; + sum = min = max = val; + count = 1; } - void add(int val) { + void add(short val) { sum += val; count++; if (val >= max) { @@ -107,19 +116,23 @@ public class CalculateAverage_abeobk { } void merge(Node other) { - min = Math.min(min, other.min); - max = Math.max(max, other.max); sum += other.sum; count += other.count; + if (other.max > max) { + max = other.max; + } + if (other.min < min) { + min = other.min; + } } - boolean contentEquals(long other_addr, long other_tail) { - if (tail != other_tail) + boolean contentEquals(long other_addr, long other_word0, long other_tail) { + if (tail != other_tail || word0 != other_word0) return false; // this is faster than comparision if key is short long xsum = 0; int n = keylen & 0xF8; - for (int i = 0; i < n; i += 8) { + for (int i = 8; i < n; i += 8) { xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i)); } return xsum == 0; @@ -156,29 +169,27 @@ public class CalculateAverage_abeobk { } // great idea from merykitty (Quan Anh Mai) - static final int parseNum(long num_word, int dot_pos) { + static final short parseNum(long num_word, int dot_pos) { int shift = 28 - dot_pos; long signed = (~num_word << 59) >> 63; long dsmask = ~(signed & 0xFF); long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; - return (int) ((abs_val ^ signed) - signed); + return (short) ((abs_val ^ signed) - signed); } // optimize for contest // save as much slow memory access as possible // about 50% key < 8chars, 25% key bettween 8-10 chars // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... - static final Node[] parse(int thread_id, long start, long end, int[] cls) { + static final Node[] parse(int thread_id, long start, long end) { + int cls = 0; long addr = start; var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions // parse loop while (addr < end) { long row_addr = addr; - long tail = 0; long hash = 0; - int val = 0; - int bucket = 0; long word0 = UNSAFE.getLong(addr); long semipos_code = getSemiPosCode(word0); @@ -191,9 +202,9 @@ public class CalculateAverage_abeobk { int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); addr += (dot_pos >>> 3) + 3; - tail = (word0 & HASH_MASKS[semi_pos]); - bucket = xxh32(tail) & BUCKET_MASK; - val = parseNum(num_word, dot_pos); + long tail = (word0 & HASH_MASKS[semi_pos]); + int bucket = xxh32(tail) & BUCKET_MASK; + short val = parseNum(num_word, dot_pos); while (true) { var node = map[bucket]; @@ -207,7 +218,7 @@ public class CalculateAverage_abeobk { } bucket++; if (SHOW_ANALYSIS) - cls[thread_id]++; + cls++; } continue; } @@ -225,15 +236,15 @@ public class CalculateAverage_abeobk { int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); addr += (dot_pos >>> 3) + 4; - tail = (word & HASH_MASKS[semi_pos]); + long tail = (word & HASH_MASKS[semi_pos]); hash ^= tail; - bucket = xxh32(hash) & BUCKET_MASK; - val = parseNum(num_word, dot_pos); + int bucket = xxh32(hash) & BUCKET_MASK; + short val = parseNum(num_word, dot_pos); while (true) { var node = map[bucket]; if (node == null) { - map[bucket] = new Node(row_addr, tail, val, keylen, word0); + map[bucket] = new Node(row_addr, word0, tail, val, keylen); break; } if (node.word0 == word0 && node.tail == tail) { @@ -242,7 +253,7 @@ public class CalculateAverage_abeobk { } bucket++; if (SHOW_ANALYSIS) - cls[thread_id]++; + cls++; } continue; } @@ -261,30 +272,55 @@ public class CalculateAverage_abeobk { int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); addr += (dot_pos >>> 3) + 4; - tail = (word & HASH_MASKS[semi_pos]); + long tail = (word & HASH_MASKS[semi_pos]); hash ^= tail; - bucket = xxh32(hash) & BUCKET_MASK; - val = parseNum(num_word, dot_pos); + int bucket = xxh32(hash) & BUCKET_MASK; + short val = parseNum(num_word, dot_pos); while (true) { var node = map[bucket]; if (node == null) { - map[bucket] = new Node(row_addr, tail, val, keylen); + map[bucket] = new Node(row_addr, word0, tail, val, keylen); break; } - if (node.contentEquals(row_addr, tail)) { + if (node.contentEquals(row_addr, word0, tail)) { node.add(val); break; } bucket++; if (SHOW_ANALYSIS) - cls[thread_id]++; + cls++; } } + if (SHOW_ANALYSIS) { + debug("Thread %d collision = %d", thread_id, cls); + } return map; } + // thomaswue trick + private static void spawnWorker() throws IOException { + ProcessHandle.Info info = ProcessHandle.current().info(); + ArrayList workerCommand = new ArrayList<>(); + info.command().ifPresent(workerCommand::add); + info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args))); + workerCommand.add("--worker"); + new ProcessBuilder() + .command(workerCommand) + .inheritIO() + .redirectOutput(ProcessBuilder.Redirect.PIPE) + .start() + .getInputStream() + .transferTo(System.out); + } + public static void main(String[] args) throws InterruptedException, IOException { + // thomaswue trick + if (args.length == 0 || !("--worker".equals(args[0]))) { + spawnWorker(); + return; + } + try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); long file_size = file.size(); @@ -295,51 +331,48 @@ public class CalculateAverage_abeobk { long chunk_size = Math.ceilDiv(file_size, cpu_cnt); // processing - var threads = new Thread[cpu_cnt]; - var maps = new Node[cpu_cnt][]; var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt); - int[] cls = new int[cpu_cnt]; // collision + TreeMap ms = new TreeMap<>(); int[] lenhist = new int[64]; // length histogram - for (int i = 0; i < cpu_cnt; i++) { - int thread_id = i; - (threads[thread_id] = new Thread(() -> { - maps[thread_id] = parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1], cls); - })).start(); - } + List> maps = IntStream.range(0, cpu_cnt) + .mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1])) + .map(map -> { + List nodes = new ArrayList<>(); + for (var node : map) { + if (node == null) + continue; + node.calcKey(); + nodes.add(node); + } + return nodes; + }) + .parallel() + .toList(); - // join all - for (var thread : threads) - thread.join(); - - // collect results - TreeMap ms = new TreeMap<>(); - for (var map : maps) { - for (var node : map) { - if (node == null) - continue; + for (var nodes : maps) { + for (var node : nodes) { if (SHOW_ANALYSIS) { int kl = node.keylen & (lenhist.length - 1); lenhist[kl] += node.count; } - var stat = ms.putIfAbsent(node.key(), node); + var stat = ms.putIfAbsent(node.key, node); if (stat != null) stat.merge(node); } } if (SHOW_ANALYSIS) { - debug("Collision stat: "); - for (int i = 0; i < cpu_cnt; i++) { - debug("thread-" + i + " collision = " + cls[i]); - } debug("Total = " + Arrays.stream(lenhist).sum()); debug("Length_histogram = " + Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray())); + return; } - else - System.out.println(ms); + + // print result + System.out.println(ms); + System.out.close(); } } } \ No newline at end of file