From 5ba094c8fded54677e787220e352a1baf74cacec Mon Sep 17 00:00:00 2001 From: Artsiom Korzun <72259616+artsiomkorzun@users.noreply.github.com> Date: Mon, 29 Jan 2024 20:36:25 +0100 Subject: [PATCH] loop similar to thomas (#634) --- .../CalculateAverage_artsiomkorzun.java | 302 ++++++++++-------- 1 file changed, 162 insertions(+), 140 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java index 2a1a387..c0cc8f9 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java @@ -33,8 +33,9 @@ import java.util.concurrent.atomic.AtomicReference; public class CalculateAverage_artsiomkorzun { private static final Path FILE = Path.of("./measurements.txt"); - private static final long SEGMENT_SIZE = 4 * 1024 * 1024; + private static final long SEGMENT_SIZE = 2 * 1024 * 1024; private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL; + private static final long LINE_PATTERN = 0x0A0A0A0A0A0A0A0AL; private static final long DOT_BITS = 0x10101000; private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); @@ -162,14 +163,14 @@ public class CalculateAverage_artsiomkorzun { return Math.round(v) / 10.0; } - private record Aggregate(int min, int max, long sum, int cnt) { + private record Aggregate(long min, long max, long sum, long cnt) { } private static class Aggregates { - private static final int ENTRIES = 64 * 1024; - private static final int SIZE = 128 * ENTRIES; - private static final int MASK = (ENTRIES - 1) << 7; + private static final long ENTRIES = 64 * 1024; + private static final long SIZE = 256 * ENTRIES; + private static final long MASK = (ENTRIES - 1) << 8; private final long pointer; @@ -179,27 +180,27 @@ public class CalculateAverage_artsiomkorzun { UNSAFE.setMemory(pointer, SIZE, (byte) 0); } - public long find(long word, int hash) { + public long find(long word, long hash) { long address = pointer + offset(hash); - long w = word(address + 24); + long w = word(address + 48); return (w == word) ? address : 0; } - public long find(long word1, long word2, int hash) { + public long find(long word1, long word2, long hash) { long address = pointer + offset(hash); - long w1 = word(address + 24); - long w2 = word(address + 32); + long w1 = word(address + 48); + long w2 = word(address + 56); return (word1 == w1) && (word2 == w2) ? address : 0; } - public long put(long reference, long word, int length, int hash) { - for (int offset = offset(hash);; offset = next(offset)) { + public long put(long reference, long word, long length, long hash) { + for (long offset = offset(hash);; offset = next(offset)) { long address = pointer + offset; - if (equal(reference, word, address + 24, length)) { + if (equal(reference, word, address + 48, length)) { return address; } - int len = UNSAFE.getInt(address); + long len = UNSAFE.getLong(address); if (len == 0) { alloc(reference, length, hash, address); return address; @@ -207,55 +208,55 @@ public class CalculateAverage_artsiomkorzun { } } - public static void update(long address, int value) { - long sum = UNSAFE.getLong(address + 8) + value; - int cnt = UNSAFE.getInt(address + 16) + 1; - short min = UNSAFE.getShort(address + 20); - short max = UNSAFE.getShort(address + 22); + public static void update(long address, long value) { + long sum = UNSAFE.getLong(address + 16) + value; + long cnt = UNSAFE.getLong(address + 24) + 1; + long min = UNSAFE.getLong(address + 32); + long max = UNSAFE.getLong(address + 40); - UNSAFE.putLong(address + 8, sum); - UNSAFE.putInt(address + 16, cnt); + UNSAFE.putLong(address + 16, sum); + UNSAFE.putLong(address + 24, cnt); if (value < min) { - UNSAFE.putShort(address + 20, (short) value); + UNSAFE.putLong(address + 32, value); } if (value > max) { - UNSAFE.putShort(address + 22, (short) value); + UNSAFE.putLong(address + 40, value); } } public void merge(Aggregates rights) { - for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 128) { + for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 256) { long rightAddress = rights.pointer + rightOffset; - int length = UNSAFE.getInt(rightAddress); + long length = UNSAFE.getLong(rightAddress); if (length == 0) { continue; } - int hash = UNSAFE.getInt(rightAddress + 4); + long hash = UNSAFE.getLong(rightAddress + 8); - for (int offset = offset(hash);; offset = next(offset)) { + for (long offset = offset(hash);; offset = next(offset)) { long address = pointer + offset; - if (equal(address + 24, rightAddress + 24, length)) { - long sum = UNSAFE.getLong(address + 8) + UNSAFE.getLong(rightAddress + 8); - int cnt = UNSAFE.getInt(address + 16) + UNSAFE.getInt(rightAddress + 16); - short min = (short) Math.min(UNSAFE.getShort(address + 20), UNSAFE.getShort(rightAddress + 20)); - short max = (short) Math.max(UNSAFE.getShort(address + 22), UNSAFE.getShort(rightAddress + 22)); + if (equal(address + 48, rightAddress + 48, length)) { + long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16); + long cnt = UNSAFE.getLong(address + 24) + UNSAFE.getLong(rightAddress + 24); + long min = Math.min(UNSAFE.getLong(address + 32), UNSAFE.getLong(rightAddress + 32)); + long max = Math.max(UNSAFE.getLong(address + 40), UNSAFE.getLong(rightAddress + 40)); - UNSAFE.putLong(address + 8, sum); - UNSAFE.putInt(address + 16, cnt); - UNSAFE.putShort(address + 20, min); - UNSAFE.putShort(address + 22, max); + UNSAFE.putLong(address + 16, sum); + UNSAFE.putLong(address + 24, cnt); + UNSAFE.putLong(address + 32, min); + UNSAFE.putLong(address + 40, max); break; } - int len = UNSAFE.getInt(address); + long len = UNSAFE.getLong(address); if (len == 0) { - UNSAFE.copyMemory(rightAddress, address, length + 24); + UNSAFE.copyMemory(rightAddress, address, length + 48); break; } } @@ -265,19 +266,19 @@ public class CalculateAverage_artsiomkorzun { public Map aggregate() { TreeMap set = new TreeMap<>(); - for (int offset = 0; offset < SIZE; offset += 128) { + for (long offset = 0; offset < SIZE; offset += 256) { long address = pointer + offset; - int length = UNSAFE.getInt(address); + long length = UNSAFE.getLong(address); if (length != 0) { - byte[] array = new byte[length - 1]; - UNSAFE.copyMemory(null, address + 24, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, array.length); + byte[] array = new byte[(int) length - 1]; + UNSAFE.copyMemory(null, address + 48, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, array.length); String key = new String(array); - long sum = UNSAFE.getLong(address + 8); - int cnt = UNSAFE.getInt(address + 16); - short min = UNSAFE.getShort(address + 20); - short max = UNSAFE.getShort(address + 22); + long sum = UNSAFE.getLong(address + 16); + long cnt = UNSAFE.getLong(address + 24); + long min = UNSAFE.getLong(address + 32); + long max = UNSAFE.getLong(address + 40); Aggregate aggregate = new Aggregate(min, max, sum, cnt); set.put(key, aggregate); @@ -287,23 +288,23 @@ public class CalculateAverage_artsiomkorzun { return set; } - private static void alloc(long reference, int length, int hash, long address) { - UNSAFE.putInt(address, length); - UNSAFE.putInt(address + 4, hash); - UNSAFE.putShort(address + 20, Short.MAX_VALUE); - UNSAFE.putShort(address + 22, Short.MIN_VALUE); - UNSAFE.copyMemory(reference, address + 24, length); + private static void alloc(long reference, long length, long hash, long address) { + UNSAFE.putLong(address, length); + UNSAFE.putLong(address + 8, hash); + UNSAFE.putLong(address + 32, Long.MAX_VALUE); + UNSAFE.putLong(address + 40, Long.MIN_VALUE); + UNSAFE.copyMemory(reference, address + 48, length); } - private static int offset(int hash) { + private static long offset(long hash) { return hash & MASK; } - private static int next(int prev) { - return (prev + 128) & (SIZE - 1); + private static long next(long prev) { + return (prev + 256) & (SIZE - 1); } - private static boolean equal(long leftAddress, long leftWord, long rightAddress, int length) { + private static boolean equal(long leftAddress, long leftWord, long rightAddress, long length) { while (length > 8) { long left = UNSAFE.getLong(leftAddress); long right = UNSAFE.getLong(rightAddress); @@ -320,7 +321,7 @@ public class CalculateAverage_artsiomkorzun { return leftWord == word(rightAddress); } - private static boolean equal(long leftAddress, long rightAddress, int length) { + private static boolean equal(long leftAddress, long rightAddress, long length) { do { long left = UNSAFE.getLong(leftAddress); long right = UNSAFE.getLong(rightAddress); @@ -362,7 +363,7 @@ public class CalculateAverage_artsiomkorzun { for (int segment; (segment = counter.getAndIncrement()) < segmentCount;) { long position = SEGMENT_SIZE * segment; - long size = Math.min(SEGMENT_SIZE, fileSize - position - 1); + long size = Math.min(SEGMENT_SIZE + 1, fileSize - position); long start = fileAddress + position; long end = start + size; @@ -374,7 +375,55 @@ public class CalculateAverage_artsiomkorzun { long left = next(start + chunk); long right = next(start + chunk + chunk); - aggregate(aggregates, start, left - 1, left, right - 1, right, end); + Chunk chunk1 = new Chunk(start, left); + Chunk chunk2 = new Chunk(left, right); + Chunk chunk3 = new Chunk(right, end); + + while (chunk1.has() && chunk2.has() && chunk3.has()) { + long word1 = word(chunk1.position); + long word2 = word(chunk2.position); + long word3 = word(chunk3.position); + + long separator1 = separator(word1); + long separator2 = separator(word2); + long separator3 = separator(word3); + + long pointer1 = find(aggregates, chunk1, word1, separator1); + long pointer2 = find(aggregates, chunk2, word2, separator2); + long pointer3 = find(aggregates, chunk3, word3, separator3); + + long value1 = value(chunk1); + long value2 = value(chunk2); + long value3 = value(chunk3); + + Aggregates.update(pointer1, value1); + Aggregates.update(pointer2, value2); + Aggregates.update(pointer3, value3); + } + + while (chunk1.has()) { + long word1 = word(chunk1.position); + long separator1 = separator(word1); + long pointer1 = find(aggregates, chunk1, word1, separator1); + long value1 = value(chunk1); + Aggregates.update(pointer1, value1); + } + + while (chunk2.has()) { + long word2 = word(chunk2.position); + long separator2 = separator(word2); + long pointer2 = find(aggregates, chunk2, word2, separator2); + long value2 = value(chunk2); + Aggregates.update(pointer2, value2); + } + + while (chunk3.has()) { + long word3 = word(chunk3.position); + long separator3 = separator(word3); + long pointer3 = find(aggregates, chunk3, word3, separator3); + long value3 = value(chunk3); + Aggregates.update(pointer3, value3); + } } while (!result.compareAndSet(null, aggregates)) { @@ -387,123 +436,82 @@ public class CalculateAverage_artsiomkorzun { } private static long next(long position) { - while (UNSAFE.getByte(position++) != '\n') { - // continue - } - return position; - } + while (true) { + long word = word(position); + long match = word ^ LINE_PATTERN; + long line = (match - 0x0101010101010101L) & (~match & 0x8080808080808080L); - private static void aggregate(Aggregates aggregates, long position1, long limit1, long position2, long limit2, long position3, long limit3) { - while (position1 <= limit1 && position2 <= limit2 && position3 <= limit3) { - long word1 = word(position1); - long word2 = word(position2); - long word3 = word(position3); + if (line == 0) { + position += 8; + continue; + } - long separator1 = separator(word1); - long separator2 = separator(word2); - long separator3 = separator(word3); - - position1 = process(aggregates, position1, word1, separator1); - position2 = process(aggregates, position2, word2, separator2); - position3 = process(aggregates, position3, word3, separator3); - } - - while (position1 <= limit1) { - long word1 = word(position1); - long separator1 = separator(word1); - position1 = process(aggregates, position1, word1, separator1); - } - - while (position2 <= limit2) { - long word2 = word(position2); - long separator2 = separator(word2); - position2 = process(aggregates, position2, word2, separator2); - } - - while (position3 <= limit3) { - long word3 = word(position3); - long separator3 = separator(word3); - position3 = process(aggregates, position3, word3, separator3); + return position + (Long.numberOfTrailingZeros(line) >>> 3) + 1; } } - private static long process(Aggregates aggregates, long position, long word, long separator) { - long end = position; - - int length; - int hash; - int value; + private static long find(Aggregates aggregates, Chunk chunk, long word, long separator) { + long start = chunk.position; + long hash; if (separator != 0) { - length = length(separator); word = mask(word, separator); hash = mix(word); - end += length; - long num = word(end); - int dot = dot(num); - value = value(num, dot); - end += (dot >> 3) + 3; + chunk.position += length(separator); long pointer = aggregates.find(word, hash); if (pointer != 0) { - Aggregates.update(pointer, value); - return end; + return pointer; } } else { long word0 = word; - word = word(end + 8); + word = word(start + 8); separator = separator(word); if (separator != 0) { - length = length(separator) + 8; word = mask(word, separator); hash = mix(word ^ word0); - end += length; - long num = word(end); - int dot = dot(num); - value = value(num, dot); - end += (dot >> 3) + 3; + chunk.position += length(separator) + 8; long pointer = aggregates.find(word0, word, hash); if (pointer != 0) { - Aggregates.update(pointer, value); - return end; + return pointer; } } else { - length = 16; - long h = word ^ word0; + chunk.position += 16; + hash = word ^ word0; while (true) { - word = word(end + length); + word = word(chunk.position); separator = separator(word); if (separator == 0) { - length += 8; - h ^= word; + chunk.position += 8; + hash ^= word; continue; } - length += length(separator); word = mask(word, separator); - hash = mix(h ^ word); - end += length; - - long num = word(end); - int dot = dot(num); - value = value(num, dot); - end += (dot >> 3) + 3; + hash = mix(hash ^ word); + chunk.position += length(separator); break; } } } - long pointer = aggregates.put(position, word, length, hash); - Aggregates.update(pointer, value); - return end; + long length = chunk.position - start; + return aggregates.put(start, word, length, hash); + } + + private static long value(Chunk chunk) { + long num = word(chunk.position); + long dot = dot(num); + chunk.position += (dot >> 3) + 3; + return value(num, dot); } private static long separator(long word) { @@ -516,28 +524,42 @@ public class CalculateAverage_artsiomkorzun { return word & mask; } - private static int length(long separator) { + private static long length(long separator) { return (Long.numberOfTrailingZeros(separator) >>> 3) + 1; } - private static int mix(long x) { + private static long mix(long x) { long h = x * -7046029254386353131L; h ^= h >>> 35; - return (int) h; + return h; // h ^= h >>> 32; // return (int) (h ^ h >>> 16); } - private static int dot(long num) { + private static long dot(long num) { return Long.numberOfTrailingZeros(~num & DOT_BITS); } - private static int value(long w, int dot) { + private static long value(long w, long dot) { long signed = (~w << 59) >> 63; long mask = ~(signed & 0xFF); long digits = ((w & mask) << (28 - dot)) & 0x0F000F0F00L; long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; - return (int) ((abs ^ signed) - signed); + return (abs ^ signed) - signed; } } -} + + private static class Chunk { + final long limit; + long position; + + public Chunk(long position, long limit) { + this.position = position; + this.limit = limit; + } + + boolean has() { + return position < limit; + } + } +} \ No newline at end of file