From 8e407ca79dc9c2f51b096f95687306103258bf75 Mon Sep 17 00:00:00 2001 From: Van Phu DO Date: Tue, 30 Jan 2024 05:21:04 +0900 Subject: [PATCH] apply loop unroll trick (#643) * apply loop unroll trick * less assign op, a bit faster --- prepare_abeobk.sh | 4 +- .../onebrc/CalculateAverage_abeobk.java | 304 ++++++++++-------- 2 files changed, 179 insertions(+), 129 deletions(-) diff --git a/prepare_abeobk.sh b/prepare_abeobk.sh index 1b73743..08a8afd 100755 --- a/prepare_abeobk.sh +++ b/prepare_abeobk.sh @@ -20,8 +20,6 @@ sdk use java 21.0.2-graal 1>&2 # ./mvnw clean verify removes target/ and will re-trigger native image creation. if [ ! -f target/CalculateAverage_abeobk_image ]; then - NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -dsa -march=native -R:MaxHeapSize=128m -H:-GenLoopSafepoints -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk" + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -dsa -march=native -H:InlineAllBonus=10 -H:-GenLoopSafepoints -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk" native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk fi - - diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java index c08a9d8..2340bca 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java @@ -98,21 +98,21 @@ public class CalculateAverage_abeobk { return new String(sbuf, 0, (int) keylen, StandardCharsets.UTF_8); } - Node(long a, long t, int kl, long h, long val) { + Node(long a, long t, int kl, long h) { addr = a; tail = t; - sum = min = max = val; - count = 1; + min = 999; + max = -999; keylen = kl; hash = h; } - Node(long a, long w0, long t, int kl, long h, long val) { + Node(long a, long w0, long t, int kl, long h) { addr = a; word0 = w0; + min = 999; + max = -999; tail = t; - sum = min = max = val; - count = 1; keylen = kl; hash = h; } @@ -120,9 +120,8 @@ public class CalculateAverage_abeobk { final void add(long val) { sum += val; count++; - if (val >= max) { + if (val > max) { max = val; - return; } if (val < min) { min = val; @@ -170,25 +169,141 @@ public class CalculateAverage_abeobk { return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L); } + static final long getLFCode(final long word) { + long xor_semi = word ^ 0x0A0A0A0A0A0A0A0AL; // xor with \n\n\n\n\n\n\n\n + return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L); + } + + static final long nextLine(long addr) { + long word = UNSAFE.getLong(addr); + long lfpos_code = getLFCode(word); + while (lfpos_code == 0) { + addr += 8; + word = UNSAFE.getLong(addr); + lfpos_code = getLFCode(word); + } + return addr + (Long.numberOfTrailingZeros(lfpos_code) >>> 3) + 1; + } + // speed/collision balance static final long xxh32(long hash) { long h = hash * 37; return (h ^ (h >>> 29)); } - // great idea from merykitty (Quan Anh Mai) - static final long parseNum(long num_word, int dot_pos) { - int shift = 28 - dot_pos; - long signed = (~num_word << 59) >> 63; - long dsmask = ~(signed & 0xFF); - long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; - long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; - return ((abs_val ^ signed) - signed); + static final class ChunkParser { + long addr; + long end; + Node[] map; + + ChunkParser(Node[] m, long a, long e) { + map = m; + addr = a; + end = e; + } + + final boolean ok() { + return addr < end; + } + + final long word() { + return UNSAFE.getLong(addr); + } + + final long val() { + long num_word = UNSAFE.getLong(addr); + int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); + addr += (dot_pos >>> 3) + 3; + // great idea from merykitty (Quan Anh Mai) + int shift = 28 - dot_pos; + long signed = (~num_word << 59) >> 63; + long dsmask = ~(signed & 0xFF); + long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; + long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; + return ((abs_val ^ signed) - signed); + } + + // optimize for contest + // save as much slow memory access as possible + // about 50% key < 8chars, 25% key bettween 8-10 chars + // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... + final Node key(long word0, long semipos_code) { + long row_addr = addr; + // about 50% chance key < 8 chars + if (semipos_code != 0) { + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + addr += semi_pos + 1; + long tail = word0 & HASH_MASKS[semi_pos]; + long hash = xxh32(tail); + int bucket = (int) (hash & BUCKET_MASK); + while (true) { + Node node = map[bucket]; + if (node == null) { + return (map[bucket] = new Node(row_addr, tail, semi_pos, hash)); + } + if (node.tail == tail) { + return node; + } + bucket++; + } + } + + addr += 8; + long word = UNSAFE.getLong(addr); + semipos_code = getSemiPosCode(word); + // 43% chance + if (semipos_code != 0) { + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + addr += semi_pos + 1; + long tail = (word & HASH_MASKS[semi_pos]); + long hash = xxh32(word0 ^ tail); + int bucket = (int) (hash & BUCKET_MASK); + while (true) { + Node node = map[bucket]; + if (node == null) { + return (map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash)); + } + if (node.word0 == word0 && node.tail == tail) { + return node; + } + bucket++; + } + } + + // why not going for more? tested, slower + long hash = word0; + while (semipos_code == 0) { + hash ^= word; + addr += 8; + word = UNSAFE.getLong(addr); + semipos_code = getSemiPosCode(word); + } + + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + addr += semi_pos; + long keylen = addr - row_addr; + addr++; + long tail = (word & HASH_MASKS[semi_pos]); + hash = xxh32(hash ^ tail); + int bucket = (int) (hash & BUCKET_MASK); + + while (true) { + Node node = map[bucket]; + if (node == null) { + return (map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash)); + } + if (node.contentEquals(row_addr, word0, tail, keylen)) { + return node; + } + bucket++; + } + } } // Thread pool worker static final class Worker extends Thread { final int thread_id; // for debug use only + int cls = 0; Worker(int i) { thread_id = i; @@ -198,9 +313,8 @@ public class CalculateAverage_abeobk { @Override public void run() { var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions - int id; - int cls = 0; + int id; // process in small chunk to maintain disk locality (artsiomkorzun trick) while ((id = chunk_id.getAndIncrement()) < chunk_cnt) { long addr = start_addr + id * CHUNK_SZ; @@ -208,119 +322,57 @@ public class CalculateAverage_abeobk { // find start of line if (id > 0) { - while (UNSAFE.getByte(addr++) != '\n') - ; + addr = nextLine(addr); } - // parse loop - // optimize for contest - // save as much slow memory access as possible - // about 50% key < 8chars, 25% key bettween 8-10 chars - // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... - while (addr < end) { - long row_addr = addr; + final int num_segs = 3; + long seglen = (end - addr) / num_segs; - long word0 = UNSAFE.getLong(addr); - long semipos_code = getSemiPosCode(word0); + long a0 = addr; + long a1 = nextLine(addr + 1 * seglen); + long a2 = nextLine(addr + 2 * seglen); + ChunkParser p0 = new ChunkParser(map, a0, a1); + ChunkParser p1 = new ChunkParser(map, a1, a2); + ChunkParser p2 = new ChunkParser(map, a2, end); - // about 50% chance key < 8 chars - if (semipos_code != 0) { - int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos + 1; - long num_word = UNSAFE.getLong(addr); - int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); - addr += (dot_pos >>> 3) + 3; + while (p0.ok() && p1.ok() && p2.ok()) { + long w0 = p0.word(); + long w1 = p1.word(); + long w2 = p2.word(); + long sc0 = getSemiPosCode(w0); + long sc1 = getSemiPosCode(w1); + long sc2 = getSemiPosCode(w2); + Node n0 = p0.key(w0, sc0); + Node n1 = p1.key(w1, sc1); + Node n2 = p2.key(w2, sc2); + long v0 = p0.val(); + long v1 = p1.val(); + long v2 = p2.val(); + n0.add(v0); + n1.add(v1); + n2.add(v2); + } - long tail = word0 & HASH_MASKS[semi_pos]; - long hash = xxh32(tail); - int bucket = (int) (hash & BUCKET_MASK); - long val = parseNum(num_word, dot_pos); - - while (true) { - var node = map[bucket]; - if (node == null) { - map[bucket] = new Node(row_addr, tail, semi_pos, hash, val); - break; - } - if (node.tail == tail) { - node.add(val); - break; - } - bucket++; - if (SHOW_ANALYSIS) - cls++; - } - continue; - } - - addr += 8; - long word = UNSAFE.getLong(addr); - semipos_code = getSemiPosCode(word); - // 43% chance - if (semipos_code != 0) { - int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos + 1; - long num_word = UNSAFE.getLong(addr); - int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); - addr += (dot_pos >>> 3) + 3; - - long tail = (word & HASH_MASKS[semi_pos]); - long hash = xxh32(word0 ^ tail); - int bucket = (int) (hash & BUCKET_MASK); - long val = parseNum(num_word, dot_pos); - - while (true) { - var node = map[bucket]; - if (node == null) { - map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash, val); - break; - } - if (node.word0 == word0 && node.tail == tail) { - node.add(val); - break; - } - bucket++; - if (SHOW_ANALYSIS) - cls++; - } - continue; - } - - // why not going for more? tested, slower - long hash = word0; - while (semipos_code == 0) { - hash ^= word; - addr += 8; - word = UNSAFE.getLong(addr); - semipos_code = getSemiPosCode(word); - } - - int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos; - long keylen = addr - row_addr; - long num_word = UNSAFE.getLong(addr + 1); - int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); - addr += (dot_pos >>> 3) + 4; - - long tail = (word & HASH_MASKS[semi_pos]); - hash = xxh32(hash ^ tail); - int bucket = (int) (hash & BUCKET_MASK); - long val = parseNum(num_word, dot_pos); - - while (true) { - var node = map[bucket]; - if (node == null) { - map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash, val); - break; - } - if (node.contentEquals(row_addr, word0, tail, keylen)) { - node.add(val); - break; - } - bucket++; - if (SHOW_ANALYSIS) - cls++; - } + while (p0.ok()) { + long w = p0.word(); + long sc = getSemiPosCode(w); + Node n = p0.key(w, sc); + long v = p0.val(); + n.add(v); + } + while (p1.ok()) { + long w = p1.word(); + long sc = getSemiPosCode(w); + Node n = p1.key(w, sc); + long v = p1.val(); + n.add(v); + } + while (p2.ok()) { + long w = p2.word(); + long sc = getSemiPosCode(w); + Node n = p2.key(w, sc); + long v = p2.val(); + n.add(v); } }