From 677d94e5cf9769d02843edf06686abfea7112d1d Mon Sep 17 00:00:00 2001 From: Van Phu DO Date: Tue, 16 Jan 2024 02:53:31 +0900 Subject: [PATCH] Optimized with less constructor args + low collision mixer (#420) * use all CPUs * use graal * optimized with less constructor arg * optimized with low collision mixer --- .../onebrc/CalculateAverage_abeobk.java | 175 ++++++++++-------- 1 file changed, 100 insertions(+), 75 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java index 1a71349..34a5552 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java @@ -28,6 +28,8 @@ import java.util.TreeMap; import sun.misc.Unsafe; public class CalculateAverage_abeobk { + private static final boolean SHOW_COLLISIONS = false; + private static final String FILE = "./measurements.txt"; private static final int BUCKET_SIZE = 1 << 16; private static final int BUCKET_MASK = BUCKET_SIZE - 1; @@ -55,69 +57,55 @@ public class CalculateAverage_abeobk { } } - // stat - private static class Stat { - private int min; - private int max; - private long sum; - private int count; + static class Node { + long addr; + long tail; + int min, max; + int count; + long sum; - Stat(int v) { - sum = min = max = v; - count = 1; - } - - void add(int val) { - min = Math.min(val, min); - max = Math.max(val, max); - sum += val; - count++; - } - - void merge(Stat other) { - min = Math.min(other.min, min); - max = Math.max(other.max, max); - sum += other.sum; - count += other.count; + String key() { + byte[] sbuf = new byte[MAX_STR_LEN]; + int keylen = (int) (tail >>> 56); + UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen); + return new String(sbuf, 0, keylen, StandardCharsets.UTF_8); } public String toString() { return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1); } - } - static class Node { - long addr; - int keylen; - int hash; - long[] buf = new long[13]; - Stat stat; - - String key() { - byte[] buf = new byte[MAX_STR_LEN]; - UNSAFE.copyMemory(null, addr, buf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen); - return new String(buf, 0, keylen, StandardCharsets.UTF_8); - } - - Node(long a, int kl, int h, int v, long[] b) { - stat = new Stat(v); + Node(long a, long t, int val) { addr = a; - keylen = kl; - hash = h; - System.arraycopy(b, 0, buf, 0, Math.ceilDiv(kl, 8)); + tail = t; + sum = min = max = val; + count = 1; } - boolean contentEquals(final long[] other_buf) { - int k = keylen / 8; - int r = keylen % 8; - // Since the city name is most likely shorter than 16 characters - // this should be faster than typical conditional checks - long sum = 0; - for (int i = 0; i < k; i++) { - sum += buf[i] ^ other_buf[i]; + void add(int val) { + min = Math.min(min, val); + max = Math.max(max, val); + sum += val; + count++; + } + + void merge(Node other) { + min = Math.min(min, other.min); + max = Math.max(max, other.max); + sum += other.sum; + count += other.count; + } + + boolean contentEquals(long other_addr, long other_tail) { + if (tail != other_tail) // compare tail & length at the same time + return false; + long my_addr = addr; + int nl = (int) (tail >> 59); + for (int i = 0; i < nl; i++, my_addr += 8, other_addr += 8) { + if (UNSAFE.getLong(my_addr) != UNSAFE.getLong(other_addr)) + return false; } - sum += (buf[k] ^ other_buf[k]) & HASH_MASKS[r]; - return sum == 0; + return true; } } @@ -135,55 +123,83 @@ public class CalculateAverage_abeobk { return ptrs; } + static final long getSemiPosCode(final long word) { + long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;; + return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L); + } + + // very low collision mixer + // idea from https://github.com/Cyan4973/xxHash/tree/dev + // zero collision on test data + static final int xxh32(long hash) { + final int p1 = 0x85EBCA77; // prime + final int p2 = 0xC2B2AE3D; // prime + int low = (int) hash; + int high = (int) (hash >>> 32); + low ^= low >> 15; + low *= p1; + high ^= high >> 13; + high *= p2; + var h = low ^ high; + return h; + } + public static void main(String[] args) throws InterruptedException, IOException { - int cpu_cnt = Runtime.getRuntime().availableProcessors(); try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); long file_size = file.size(); long end_addr = start_addr + file_size; + + // only use all cpus on large file + int cpu_cnt = file_size < 1e6 ? 1 : Runtime.getRuntime().availableProcessors(); long chunk_size = Math.ceilDiv(file_size, cpu_cnt); // processing var threads = new Thread[cpu_cnt]; var maps = new Node[cpu_cnt][]; var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt); + int[] cls = new int[cpu_cnt]; for (int i = 0; i < cpu_cnt; i++) { int thread_id = i; long start = ptrs[i]; long end = ptrs[i + 1]; - maps[i] = new Node[BUCKET_SIZE + 16]; // extra space for collisions + maps[i] = new Node[BUCKET_SIZE + 10000]; // extra space for collisions (threads[i] = new Thread(() -> { long addr = start; var map = maps[thread_id]; - long[] buf = new long[13]; // parse loop while (addr < end) { - int idx = 0; long hash = 0; long word = 0; long row_addr = addr; int semi_pos = 8; - while (semi_pos == 8) { + word = UNSAFE.getLong(addr); + long semipos_code = getSemiPosCode(word); + + while (semipos_code == 0) { + hash ^= word; + addr += 8; word = UNSAFE.getLong(addr); - buf[idx++] = word; - // idea from thomaswue & royvanrijn - long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;; - long semipos_code = (xor_semi - 0x0101010101010101L) & ~xor_semi & 0x8080808080808080L; - semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos; - hash ^= word & HASH_MASKS[semi_pos]; + semipos_code = getSemiPosCode(word); } - int hash32 = (int) (hash ^ (hash >>> 31)); - int keylen = (int) (addr - row_addr); + semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + long tail = word & HASH_MASKS[semi_pos]; + hash ^= tail; + addr += semi_pos; + + int hash32 = xxh32(hash); + long keylen = (addr - row_addr); + tail = tail | (keylen << 56); + + addr++; // great idea from merykitty (Quan Anh Mai) - long num_word = UNSAFE.getLong(++addr); + long num_word = UNSAFE.getLong(addr); int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); addr += (dot_pos >>> 3) + 3; - int shift = 28 - dot_pos; long signed = (~num_word << 59) >> 63; long dsmask = ~(signed & 0xFF); @@ -195,14 +211,16 @@ public class CalculateAverage_abeobk { while (true) { var node = map[bucket]; if (node == null) { - map[bucket] = new Node(row_addr, keylen, hash32, val, buf); + map[bucket] = new Node(row_addr, tail, val); break; } - if (node.keylen == keylen && node.hash == hash32 && node.contentEquals(buf)) { - node.stat.add(val); + if (node.contentEquals(row_addr, tail)) { + node.add(val); break; } bucket++; + if (SHOW_COLLISIONS) + cls[thread_id]++; } } })).start(); @@ -212,19 +230,26 @@ public class CalculateAverage_abeobk { for (var thread : threads) thread.join(); + if (SHOW_COLLISIONS) { + for (int i = 0; i < cpu_cnt; i++) { + System.out.println("thread-" + i + " collision = " + cls[i]); + } + } + // collect results - TreeMap ms = new TreeMap<>(); + TreeMap ms = new TreeMap<>(); for (var map : maps) { for (var node : map) { if (node == null) continue; - var stat = ms.putIfAbsent(node.key(), node.stat); + var stat = ms.putIfAbsent(node.key(), node); if (stat != null) - stat.merge(node.stat); + stat.merge(node); } } - System.out.println(ms); + if (!SHOW_COLLISIONS) + System.out.println(ms); } } } \ No newline at end of file