From 1804fc5b5f48c20c26c6e12a7a34315796f2fae3 Mon Sep 17 00:00:00 2001 From: Van Phu DO Date: Wed, 17 Jan 2024 06:31:00 +0900 Subject: [PATCH] Native build, less memory acess, improved hash mixing (#449) --- calculate_average_abeobk.sh | 11 +- prepare_abeobk.sh | 25 ++ .../onebrc/CalculateAverage_abeobk.java | 257 ++++++++++++------ 3 files changed, 213 insertions(+), 80 deletions(-) create mode 100755 prepare_abeobk.sh diff --git a/calculate_average_abeobk.sh b/calculate_average_abeobk.sh index a7b43d4..18c4c94 100755 --- a/calculate_average_abeobk.sh +++ b/calculate_average_abeobk.sh @@ -15,5 +15,12 @@ # limitations under the License. # -JAVA_OPTS="--enable-preview" -java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_abeobk +if [ -f target/CalculateAverage_abeobk_image ]; then + echo "Picking up existing native image 'target/CalculateAverage_abeobk_image', delete the file to select JVM mode." 1>&2 + target/CalculateAverage_abeobk_image +else + JAVA_OPTS="--enable-preview" + echo "Chosing to run the app in JVM mode as no native image was found, use prepare_abeobk.sh to generate." 1>&2 + java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_abeobk +fi + diff --git a/prepare_abeobk.sh b/prepare_abeobk.sh new file mode 100755 index 0000000..bf2b7b5 --- /dev/null +++ b/prepare_abeobk.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.1-graal 1>&2 + +# ./mvnw clean verify removes target/ and will re-trigger native image creation. +if [ ! -f target/CalculateAverage_abeobk_image ]; then + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview" + native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk +fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java index 34a5552..ec6c9e5 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java @@ -24,11 +24,12 @@ import java.nio.channels.FileChannel.MapMode; import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; +import java.util.Arrays; import java.util.TreeMap; import sun.misc.Unsafe; public class CalculateAverage_abeobk { - private static final boolean SHOW_COLLISIONS = false; + private static final boolean SHOW_ANALYSIS = false; private static final String FILE = "./measurements.txt"; private static final int BUCKET_SIZE = 1 << 16; @@ -99,13 +100,13 @@ public class CalculateAverage_abeobk { boolean contentEquals(long other_addr, long other_tail) { if (tail != other_tail) // compare tail & length at the same time return false; - long my_addr = addr; - int nl = (int) (tail >> 59); - for (int i = 0; i < nl; i++, my_addr += 8, other_addr += 8) { - if (UNSAFE.getLong(my_addr) != UNSAFE.getLong(other_addr)) - return false; + // this is faster than comparison if key is short + long xsum = 0; + int n = ((int) (tail >>> 56)) & 0xF8; + for (int i = 0; i < n; i += 8) { + xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i)); } - return true; + return 
xsum == 0; } } @@ -123,6 +124,7 @@ public class CalculateAverage_abeobk { return ptrs; } + // idea from royvanrijn static final long getSemiPosCode(final long word) { long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;; return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L); @@ -133,17 +135,164 @@ public class CalculateAverage_abeobk { // zero collision on test data static final int xxh32(long hash) { final int p1 = 0x85EBCA77; // prime - final int p2 = 0xC2B2AE3D; // prime + final int p2 = 0x165667B1; // prime int low = (int) hash; - int high = (int) (hash >>> 32); - low ^= low >> 15; - low *= p1; - high ^= high >> 13; - high *= p2; - var h = low ^ high; + int high = (int) (hash >>> 31); + int h = low + high; + h ^= h >> 15; + h *= p1; + h ^= h >> 13; + h *= p2; + h ^= h >> 11; return h; } + // great idea from merykitty (Quan Anh Mai) + static final int parseNum(long num_word, int dot_pos) { + int shift = 28 - dot_pos; + long signed = (~num_word << 59) >> 63; + long dsmask = ~(signed & 0xFF); + long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; + long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; + return (int) ((abs_val ^ signed) - signed); + } + + // optimize for contest + // save as much slow memory access as possible + // about 50% key < 8chars, 25% key between 8-10 chars + // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... 
+ static final Node[] parse(int thread_id, long start, long end, int[] cls) { + long addr = start; + var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions + // parse loop + while (addr < end) { + long row_addr = addr; + long tail = 0; + long hash = 0; + int val = 0; + int bucket = 0; + + long word = UNSAFE.getLong(addr); + long semipos_code = getSemiPosCode(word); + + // about 50% chance key < 8 chars + if (semipos_code != 0) { + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + addr += semi_pos; + tail = (word & HASH_MASKS[semi_pos]); + bucket = xxh32(tail) & BUCKET_MASK; + long keylen = (addr - row_addr); + tail |= (keylen << 56); + long num_word = UNSAFE.getLong(++addr); + int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); + val = parseNum(num_word, dot_pos); + addr += (dot_pos >>> 3) + 3; + + while (true) { + var node = map[bucket]; + if (node == null) { + map[bucket] = new Node(row_addr, tail, val); + break; + } + if (node.tail == tail) { + node.add(val); + break; + } + bucket++; + if (SHOW_ANALYSIS) + cls[thread_id]++; + } + continue; + } + + hash ^= word; + addr += 8; + word = UNSAFE.getLong(addr); + semipos_code = getSemiPosCode(word); + // first byte semicolon ~13% + if (semipos_code == 0x80) { + bucket = xxh32(hash) & BUCKET_MASK; + tail = 8L << 56; + long num_word = word >>> 8; + int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); + val = parseNum(num_word, dot_pos); + addr += (dot_pos >>> 3) + 4; + + while (true) { + var node = map[bucket]; + if (node == null) { + map[bucket] = new Node(row_addr, tail, val); + break; + } + if (UNSAFE.getLong(node.addr) == UNSAFE.getLong(row_addr)) { + node.add(val); + break; + } + bucket++; + if (SHOW_ANALYSIS) + cls[thread_id]++; + } + continue; + } + + while (semipos_code == 0) { + hash ^= word; + addr += 8; + word = UNSAFE.getLong(addr); + semipos_code = getSemiPosCode(word); + } + + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + addr += 
semi_pos; + tail = (word & HASH_MASKS[semi_pos]); + hash ^= tail; + bucket = xxh32(hash) & BUCKET_MASK; + long keylen = (addr - row_addr); + tail |= (keylen << 56); + + ++addr; + long num_word = UNSAFE.getLong(addr); + int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); + val = parseNum(num_word, dot_pos); + addr += (dot_pos >>> 3) + 3; + + if (keylen < 16) { + while (true) { + var node = map[bucket]; + if (node == null) { + map[bucket] = new Node(row_addr, tail, val); + break; + } + if (node.tail == tail && (UNSAFE.getLong(node.addr) == UNSAFE.getLong(row_addr))) { + node.add(val); + break; + } + bucket++; + if (SHOW_ANALYSIS) + cls[thread_id]++; + } + continue; + } + + // longer key + while (true) { + var node = map[bucket]; + if (node == null) { + map[bucket] = new Node(row_addr, tail, val); + break; + } + if (node.contentEquals(row_addr, tail)) { + node.add(val); + break; + } + bucket++; + if (SHOW_ANALYSIS) + cls[thread_id]++; + } + } + return map; + } + public static void main(String[] args) throws InterruptedException, IOException { try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); @@ -158,71 +307,14 @@ public class CalculateAverage_abeobk { var threads = new Thread[cpu_cnt]; var maps = new Node[cpu_cnt][]; var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt); - int[] cls = new int[cpu_cnt]; + + int[] cls = new int[cpu_cnt]; // collision + int[] lenhist = new int[64]; // length histogram for (int i = 0; i < cpu_cnt; i++) { int thread_id = i; - long start = ptrs[i]; - long end = ptrs[i + 1]; - maps[i] = new Node[BUCKET_SIZE + 10000]; // extra space for collisions - - (threads[i] = new Thread(() -> { - long addr = start; - var map = maps[thread_id]; - // parse loop - while (addr < end) { - long hash = 0; - long word = 0; - long row_addr = addr; - int semi_pos = 8; - word = UNSAFE.getLong(addr); - long semipos_code = 
getSemiPosCode(word); - - while (semipos_code == 0) { - hash ^= word; - addr += 8; - word = UNSAFE.getLong(addr); - semipos_code = getSemiPosCode(word); - } - - semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - long tail = word & HASH_MASKS[semi_pos]; - hash ^= tail; - addr += semi_pos; - - int hash32 = xxh32(hash); - long keylen = (addr - row_addr); - tail = tail | (keylen << 56); - - addr++; - - // great idea from merykitty (Quan Anh Mai) - long num_word = UNSAFE.getLong(addr); - int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); - addr += (dot_pos >>> 3) + 3; - int shift = 28 - dot_pos; - long signed = (~num_word << 59) >> 63; - long dsmask = ~(signed & 0xFF); - long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; - long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; - int val = (int) ((abs_val ^ signed) - signed); - - int bucket = (hash32 & BUCKET_MASK); - while (true) { - var node = map[bucket]; - if (node == null) { - map[bucket] = new Node(row_addr, tail, val); - break; - } - if (node.contentEquals(row_addr, tail)) { - node.add(val); - break; - } - bucket++; - if (SHOW_COLLISIONS) - cls[thread_id]++; - } - } + (threads[thread_id] = new Thread(() -> { + maps[thread_id] = parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1], cls); })).start(); } @@ -230,7 +322,7 @@ public class CalculateAverage_abeobk { for (var thread : threads) thread.join(); - if (SHOW_COLLISIONS) { + if (SHOW_ANALYSIS) { for (int i = 0; i < cpu_cnt; i++) { System.out.println("thread-" + i + " collision = " + cls[i]); } @@ -242,13 +334,22 @@ public class CalculateAverage_abeobk { for (var node : map) { if (node == null) continue; + if (SHOW_ANALYSIS) { + int kl = (int) (node.tail >>> 56) & (lenhist.length - 1); + lenhist[kl] += node.count; + } var stat = ms.putIfAbsent(node.key(), node); if (stat != null) stat.merge(node); } } - if (!SHOW_COLLISIONS) + if (SHOW_ANALYSIS) { + System.out.println("total=" + Arrays.stream(lenhist).sum()); + 
System.out.println("length_histogram = " + + Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray())); + } + else System.out.println(ms); } }