From a33ed2181b0cc71882e009e5f445d36009e3b07c Mon Sep 17 00:00:00 2001 From: Van Phu DO Date: Mon, 29 Jan 2024 02:08:42 +0900 Subject: [PATCH] Use native type, remove lots of type conversions (#618) * less type conversion, less string cast * adjust some comments * fixed format issue --- .../onebrc/CalculateAverage_abeobk.java | 179 ++++++++++-------- 1 file changed, 99 insertions(+), 80 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java index 06cbc17..c08a9d8 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java @@ -26,9 +26,9 @@ import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.TreeMap; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; import sun.misc.Unsafe; @@ -39,7 +39,7 @@ public class CalculateAverage_abeobk { private static final String FILE = "./measurements.txt"; private static final int BUCKET_SIZE = 1 << 16; - private static final int BUCKET_MASK = BUCKET_SIZE - 1; + private static final long BUCKET_MASK = BUCKET_SIZE - 1; private static final int MAX_STR_LEN = 100; private static final int MAX_STATIONS = 10000; private static final long CHUNK_SZ = 1 << 22; // 4MB chunk @@ -56,9 +56,9 @@ public class CalculateAverage_abeobk { 0xffffffffffffffffL, }; private static AtomicInteger chunk_id = new AtomicInteger(0); + private static AtomicReference mapref = new AtomicReference<>(null); private static int chunk_cnt; private static long start_addr, end_addr; - private static Stat[][] all_res; private static final void debug(String s, Object... args) { System.out.println(String.format(s, args)); @@ -75,57 +75,49 @@ public class CalculateAverage_abeobk { } } - static class Stat { - Node node; - String key; - - public final String toString() { - return (node.min / 10.0) + "/" - + (Math.round(((double) node.sum / node.count)) / 10.0) + "/" - + (node.max / 10.0); - } - - Stat(Node n) { - node = n; - byte[] sbuf = new byte[MAX_STR_LEN]; - long word = UNSAFE.getLong(n.addr); - long semipos_code = getSemiPosCode(word); - int keylen = 0; - while (semipos_code == 0) { - keylen += 8; - word = UNSAFE.getLong(n.addr + keylen); - semipos_code = getSemiPosCode(word); - } - keylen += Long.numberOfTrailingZeros(semipos_code) >>> 3; - UNSAFE.copyMemory(null, n.addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen); - key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8); - } - } - + // use native type, less conversion static class Node { long addr; + long hash; long word0; long tail; long sum; + long min, max; + int keylen; int count; - short min, max; - Node(long a, long t, short val) { + public final String toString() { + return (min / 10.0) + "/" + + (Math.round(((double) sum / count)) / 10.0) + "/" + + (max / 10.0); + } + + final String key() { + byte[] sbuf = new byte[MAX_STR_LEN]; + UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen); + return new String(sbuf, 0, (int) keylen, StandardCharsets.UTF_8); + } + + Node(long a, long t, int kl, long h, long val) { addr = a; tail = t; sum = min = max = val; count = 1; + keylen = kl; + hash = h; } - Node(long a, long w0, long t, short val) { + Node(long a, long w0, long t, int kl, long h, long val) { addr = a; word0 = w0; tail = t; sum = min = max = val; count = 1; + keylen = kl; + hash = h; } - final void add(short val) { + final void add(long val) { sum += val; count++; if (val >= max) { @@ -148,17 +140,28 @@ public class CalculateAverage_abeobk { } } - final boolean contentEquals(long other_addr, long other_word0, long other_tail, int keylen) { + final boolean contentEquals(long other_addr, long other_word0, long other_tail, long kl) { if (word0 != other_word0 || tail != other_tail) return false; // this is faster than comparision if key is short long xsum = 0; - int n = keylen & 0xF8; - for (int i = 8; i < n; i += 8) { + long n = kl & 0xF8; + for (long i = 8; i < n; i += 8) { xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i)); } return xsum == 0; } + + final boolean contentEquals(Node other) { + if (tail != other.tail) + return false; + long n = keylen & 0xF8; + for (long i = 0; i < n; i += 8) { + if (UNSAFE.getLong(addr + i) != UNSAFE.getLong(other.addr + i)) + return false; + } + return true; + } } // idea from royvanrijn @@ -168,24 +171,24 @@ public class CalculateAverage_abeobk { } // speed/collision balance - static final int xxh32(long hash) { + static final long xxh32(long hash) { long h = hash * 37; - return (int) (h ^ (h >>> 29)); + return (h ^ (h >>> 29)); } // great idea from merykitty (Quan Anh Mai) - static final short parseNum(long num_word, int dot_pos) { + static final long parseNum(long num_word, int dot_pos) { int shift = 28 - dot_pos; long signed = (~num_word << 59) >> 63; long dsmask = ~(signed & 0xFF); long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; - return (short) ((abs_val ^ signed) - signed); + return ((abs_val ^ signed) - signed); } // Thread pool worker static final class Worker extends Thread { - final int thread_id; + final int thread_id; // for debug use only Worker(int i) { thread_id = i; @@ -195,16 +198,15 @@ public class CalculateAverage_abeobk { @Override public void run() { var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions - int cnt = 0; int id; int cls = 0; // process in small chunk to maintain disk locality (artsiomkorzun trick) - // but keep going instead of merging while ((id = chunk_id.getAndIncrement()) < chunk_cnt) { long addr = start_addr + id * CHUNK_SZ; long end = Math.min(addr + CHUNK_SZ, end_addr); - // adjust start + + // find start of line if (id > 0) { while (UNSAFE.getByte(addr++) != '\n') ; @@ -230,14 +232,14 @@ public class CalculateAverage_abeobk { addr += (dot_pos >>> 3) + 3; long tail = word0 & HASH_MASKS[semi_pos]; - int bucket = xxh32(tail) & BUCKET_MASK; - short val = parseNum(num_word, dot_pos); + long hash = xxh32(tail); + int bucket = (int) (hash & BUCKET_MASK); + long val = parseNum(num_word, dot_pos); while (true) { var node = map[bucket]; if (node == null) { - map[bucket] = new Node(row_addr, tail, val); - cnt++; + map[bucket] = new Node(row_addr, tail, semi_pos, hash, val); break; } if (node.tail == tail) { @@ -263,14 +265,14 @@ public class CalculateAverage_abeobk { addr += (dot_pos >>> 3) + 3; long tail = (word & HASH_MASKS[semi_pos]); - int bucket = xxh32(word0 ^ tail) & BUCKET_MASK; - short val = parseNum(num_word, dot_pos); + long hash = xxh32(word0 ^ tail); + int bucket = (int) (hash & BUCKET_MASK); + long val = parseNum(num_word, dot_pos); while (true) { var node = map[bucket]; if (node == null) { - map[bucket] = new Node(row_addr, word0, tail, val); - cnt++; + map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash, val); break; } if (node.word0 == word0 && node.tail == tail) { @@ -295,20 +297,20 @@ public class CalculateAverage_abeobk { int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; addr += semi_pos; - int keylen = (int) (addr - row_addr); + long keylen = addr - row_addr; long num_word = UNSAFE.getLong(addr + 1); int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); addr += (dot_pos >>> 3) + 4; long tail = (word & HASH_MASKS[semi_pos]); - int bucket = xxh32(hash ^ tail) & BUCKET_MASK; - short val = parseNum(num_word, dot_pos); + hash = xxh32(hash ^ tail); + int bucket = (int) (hash & BUCKET_MASK); + long val = parseNum(num_word, dot_pos); while (true) { var node = map[bucket]; if (node == null) { - map[bucket] = new Node(row_addr, word0, tail, val); - cnt++; + map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash, val); break; } if (node.contentEquals(row_addr, word0, tail, keylen)) { @@ -322,18 +324,36 @@ public class CalculateAverage_abeobk { } } + // merge is cheaper than string casting (artsiomkorzun) + while (!mapref.compareAndSet(null, map)) { + var other_map = mapref.getAndSet(null); + if (other_map != null) { + for (int i = 0; i < other_map.length; i++) { + var other = other_map[i]; + if (other == null) + continue; + int bucket = (int) (other.hash & BUCKET_MASK); + while (true) { + var node = map[bucket]; + if (node == null) { + map[bucket] = other; + break; + } + if (node.contentEquals(other)) { + node.merge(other); + break; + } + bucket++; + if (SHOW_ANALYSIS) + cls++; + } + } + } + } + if (SHOW_ANALYSIS) { debug("Thread %d collision = %d", thread_id, cls); } - - Stat[] stats = new Stat[cnt]; - int i = 0; - for (var node : map) { - if (node != null) { - stats[i++] = new Stat(node); - } - } - all_res[thread_id] = stats; } } @@ -366,23 +386,22 @@ public class CalculateAverage_abeobk { // only use all cpus on large file int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT; chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ); - all_res = new Stat[cpu_cnt][]; - List workers = IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList(); - for (var w : workers) + // spawn workers + for (var w : IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList()) { w.join(); - - // collect all results - TreeMap ms = new TreeMap<>(); - for (var res : all_res) { - for (var s : res) { - var stat = ms.putIfAbsent(s.key, s); - if (stat != null) - stat.node.merge(s.node); - } } - // print output + // collect results + TreeMap ms = new TreeMap<>(); + for (var crr : mapref.get()) { + if (crr == null) + continue; + var prev = ms.putIfAbsent(crr.key(), crr); + if (prev != null) + prev.merge(crr); + } + // print result System.out.println(ms); System.out.close(); }