From c228633b5753e2565e93833cf0ef8af23c66ac77 Mon Sep 17 00:00:00 2001 From: Van Phu DO Date: Sat, 27 Jan 2024 22:54:43 +0900 Subject: [PATCH] improve hard disk access locality, another 8% (#591) * improve hard disk access locality, another 8% * add some comments & credit * fixed format --- .../onebrc/CalculateAverage_abeobk.java | 339 +++++++++--------- 1 file changed, 178 insertions(+), 161 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java index ed859f3..06cbc17 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java @@ -28,18 +28,21 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.TreeMap; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.IntStream; import sun.misc.Unsafe; public class CalculateAverage_abeobk { private static final boolean SHOW_ANALYSIS = false; + private static final int CPU_CNT = Runtime.getRuntime().availableProcessors(); private static final String FILE = "./measurements.txt"; private static final int BUCKET_SIZE = 1 << 16; private static final int BUCKET_MASK = BUCKET_SIZE - 1; private static final int MAX_STR_LEN = 100; private static final int MAX_STATIONS = 10000; + private static final long CHUNK_SZ = 1 << 22; // 4MB chunk private static final Unsafe UNSAFE = initUnsafe(); private static final long[] HASH_MASKS = new long[]{ 0x0L, @@ -52,6 +55,11 @@ public class CalculateAverage_abeobk { 0xffffffffffffffL, 0xffffffffffffffffL, }; + private static AtomicInteger chunk_id = new AtomicInteger(0); + private static int chunk_cnt; + private static long start_addr, end_addr; + private static Stat[][] all_res; + private static final void debug(String s, Object... args) { System.out.println(String.format(s, args)); } @@ -153,20 +161,6 @@ public class CalculateAverage_abeobk { } } - // split into chunks - static long[] slice(long start_addr, long end_addr, long chunk_size, int cpu_cnt) { - long[] ptrs = new long[cpu_cnt + 1]; - ptrs[0] = start_addr; - for (int i = 1; i < cpu_cnt; i++) { - long addr = start_addr + i * chunk_size; - while (addr < end_addr && UNSAFE.getByte(addr++) != '\n') - ; - ptrs[i] = Math.min(addr, end_addr); - } - ptrs[cpu_cnt] = end_addr; - return ptrs; - } - // idea from royvanrijn static final long getSemiPosCode(final long word) { long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;; @@ -189,123 +183,158 @@ public class CalculateAverage_abeobk { return (short) ((abs_val ^ signed) - signed); } - // optimize for contest - // save as much slow memory access as possible - // about 50% key < 8chars, 25% key bettween 8-10 chars - // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... - static final Node[] parse(int thread_id, long start, long end) { - int cls = 0; - long addr = start; - var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions - // parse loop - while (addr < end) { - long row_addr = addr; + // Thread pool worker + static final class Worker extends Thread { + final int thread_id; - long word0 = UNSAFE.getLong(addr); - long semipos_code = getSemiPosCode(word0); - - // about 50% chance key < 8 chars - if (semipos_code != 0) { - int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos + 1; - long num_word = UNSAFE.getLong(addr); - int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); - addr += (dot_pos >>> 3) + 3; - - long tail = word0 & HASH_MASKS[semi_pos]; - int bucket = xxh32(tail) & BUCKET_MASK; - short val = parseNum(num_word, dot_pos); - - while (true) { - var node = map[bucket]; - if (node == null) { - map[bucket] = new Node(row_addr, tail, val); - break; - } - if (node.tail == tail) { - node.add(val); - break; - } - bucket++; - if (SHOW_ANALYSIS) - cls++; - } - continue; - } - - addr += 8; - long word = UNSAFE.getLong(addr); - semipos_code = getSemiPosCode(word); - // 43% chance - if (semipos_code != 0) { - int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos + 1; - long num_word = UNSAFE.getLong(addr); - int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); - addr += (dot_pos >>> 3) + 3; - - long tail = (word & HASH_MASKS[semi_pos]); - int bucket = xxh32(word0 ^ tail) & BUCKET_MASK; - short val = parseNum(num_word, dot_pos); - - while (true) { - var node = map[bucket]; - if (node == null) { - map[bucket] = new Node(row_addr, word0, tail, val); - break; - } - if (node.word0 == word0 && node.tail == tail) { - node.add(val); - break; - } - bucket++; - if (SHOW_ANALYSIS) - cls++; - } - continue; - } - - // why not going for more? tested, slower - - long hash = word0; - while (semipos_code == 0) { - hash ^= word; - addr += 8; - word = UNSAFE.getLong(addr); - semipos_code = getSemiPosCode(word); - } - - int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; - addr += semi_pos; - int keylen = (int) (addr - row_addr); - long num_word = UNSAFE.getLong(addr + 1); - int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); - addr += (dot_pos >>> 3) + 4; - - long tail = (word & HASH_MASKS[semi_pos]); - int bucket = xxh32(hash ^ tail) & BUCKET_MASK; - short val = parseNum(num_word, dot_pos); - - while (true) { - var node = map[bucket]; - if (node == null) { - map[bucket] = new Node(row_addr, word0, tail, val); - break; - } - if (node.contentEquals(row_addr, word0, tail, keylen)) { - node.add(val); - break; - } - bucket++; - if (SHOW_ANALYSIS) - cls++; - } + Worker(int i) { + thread_id = i; + this.start(); } - if (SHOW_ANALYSIS) { - debug("Thread %d collision = %d", thread_id, cls); + @Override + public void run() { + var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions + int cnt = 0; + int id; + int cls = 0; + + // process in small chunk to maintain disk locality (artsiomkorzun trick) + // but keep going instead of merging + while ((id = chunk_id.getAndIncrement()) < chunk_cnt) { + long addr = start_addr + id * CHUNK_SZ; + long end = Math.min(addr + CHUNK_SZ, end_addr); + // adjust start + if (id > 0) { + while (UNSAFE.getByte(addr++) != '\n') + ; + } + + // parse loop + // optimize for contest + // save as much slow memory access as possible + // about 50% key < 8chars, 25% key bettween 8-10 chars + // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... + while (addr < end) { + long row_addr = addr; + + long word0 = UNSAFE.getLong(addr); + long semipos_code = getSemiPosCode(word0); + + // about 50% chance key < 8 chars + if (semipos_code != 0) { + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + addr += semi_pos + 1; + long num_word = UNSAFE.getLong(addr); + int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); + addr += (dot_pos >>> 3) + 3; + + long tail = word0 & HASH_MASKS[semi_pos]; + int bucket = xxh32(tail) & BUCKET_MASK; + short val = parseNum(num_word, dot_pos); + + while (true) { + var node = map[bucket]; + if (node == null) { + map[bucket] = new Node(row_addr, tail, val); + cnt++; + break; + } + if (node.tail == tail) { + node.add(val); + break; + } + bucket++; + if (SHOW_ANALYSIS) + cls++; + } + continue; + } + + addr += 8; + long word = UNSAFE.getLong(addr); + semipos_code = getSemiPosCode(word); + // 43% chance + if (semipos_code != 0) { + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + addr += semi_pos + 1; + long num_word = UNSAFE.getLong(addr); + int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); + addr += (dot_pos >>> 3) + 3; + + long tail = (word & HASH_MASKS[semi_pos]); + int bucket = xxh32(word0 ^ tail) & BUCKET_MASK; + short val = parseNum(num_word, dot_pos); + + while (true) { + var node = map[bucket]; + if (node == null) { + map[bucket] = new Node(row_addr, word0, tail, val); + cnt++; + break; + } + if (node.word0 == word0 && node.tail == tail) { + node.add(val); + break; + } + bucket++; + if (SHOW_ANALYSIS) + cls++; + } + continue; + } + + // why not going for more? tested, slower + long hash = word0; + while (semipos_code == 0) { + hash ^= word; + addr += 8; + word = UNSAFE.getLong(addr); + semipos_code = getSemiPosCode(word); + } + + int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; + addr += semi_pos; + int keylen = (int) (addr - row_addr); + long num_word = UNSAFE.getLong(addr + 1); + int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); + addr += (dot_pos >>> 3) + 4; + + long tail = (word & HASH_MASKS[semi_pos]); + int bucket = xxh32(hash ^ tail) & BUCKET_MASK; + short val = parseNum(num_word, dot_pos); + + while (true) { + var node = map[bucket]; + if (node == null) { + map[bucket] = new Node(row_addr, word0, tail, val); + cnt++; + break; + } + if (node.contentEquals(row_addr, word0, tail, keylen)) { + node.add(val); + break; + } + bucket++; + if (SHOW_ANALYSIS) + cls++; + } + } + } + + if (SHOW_ANALYSIS) { + debug("Thread %d collision = %d", thread_id, cls); + } + + Stat[] stats = new Stat[cnt]; + int i = 0; + for (var node : map) { + if (node != null) { + stats[i++] = new Stat(node); + } + } + all_res[thread_id] = stats; } - return map; } // thomaswue trick @@ -329,44 +358,32 @@ public class CalculateAverage_abeobk { return; } - try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { - long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); - long file_size = file.size(); - long end_addr = start_addr + file_size; + var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ); + long file_size = file.size(); + start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); + end_addr = start_addr + file_size; - // only use all cpus on large file - int cpu_cnt = file_size < 1e6 ? 1 : Runtime.getRuntime().availableProcessors(); - long chunk_size = Math.ceilDiv(file_size, cpu_cnt); + // only use all cpus on large file + int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT; + chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ); + all_res = new Stat[cpu_cnt][]; - // processing - var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt); + List workers = IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList(); + for (var w : workers) + w.join(); - List> maps = IntStream.range(0, cpu_cnt) - .mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1])) - .map(map -> { - List stats = new ArrayList<>(); - for (var node : map) { - if (node == null) - continue; - stats.add(new Stat(node)); - } - return stats; - }) - .parallel() - .toList(); - - TreeMap ms = new TreeMap<>(); - for (var stats : maps) { - for (var s : stats) { - var stat = ms.putIfAbsent(s.key, s); - if (stat != null) - stat.node.merge(s.node); - } + // collect all results + TreeMap ms = new TreeMap<>(); + for (var res : all_res) { + for (var s : res) { + var stat = ms.putIfAbsent(s.key, s); + if (stat != null) + stat.node.merge(s.node); } - - // print result - System.out.println(ms); - System.out.close(); } + + // print output + System.out.println(ms); + System.out.close(); } } \ No newline at end of file