diff --git a/calculate_average_artsiomkorzun.sh b/calculate_average_artsiomkorzun.sh index eaf050c..96e3467 100755 --- a/calculate_average_artsiomkorzun.sh +++ b/calculate_average_artsiomkorzun.sh @@ -15,5 +15,5 @@ # limitations under the License. # -JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC" +JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_artsiomkorzun diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java index c9b7144..4f6c8fd 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java @@ -35,7 +35,7 @@ public class CalculateAverage_artsiomkorzun { private static final MemorySegment MAPPED_FILE = map(FILE); private static final int PARALLELISM = Runtime.getRuntime().availableProcessors(); - private static final int SEGMENT_SIZE = 16 * 1024 * 1024; + private static final int SEGMENT_SIZE = 32 * 1024 * 1024; private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE); private static final int SEGMENT_OVERLAP = 1024; private static final long COMMA_PATTERN = pattern(';'); @@ -100,16 +100,6 @@ public class CalculateAverage_artsiomkorzun { return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56); } - private static long getLongBigEndian(long address) { - long value = UNSAFE.getLong(address); - - if (BYTE_ORDER == ByteOrder.LITTLE_ENDIAN) { - value = Long.reverseBytes(value); - } - - return value; - } - private static long getLongLittleEndian(long address) { long value = UNSAFE.getLong(address); @@ -144,98 +134,80 @@ public class CalculateAverage_artsiomkorzun { return Math.round(v) / 10.0; } - private static class Row { - long address; - int length; - int hash; - int value; - } - private record Aggregate(int min, int max, long sum, int cnt) { } private static class Aggregates { - private static final int SIZE = 16 * 1024; + private static final int ENTRIES = 64 * 1024; + private static final int SIZE = 32 * ENTRIES; + private final long pointer; public Aggregates() { - int size = 32 * SIZE; - long address = UNSAFE.allocateMemory(size + 8096); + long address = UNSAFE.allocateMemory(SIZE + 8096); pointer = (address + 4095) & (~4095); - UNSAFE.setMemory(pointer, size, (byte) 0); - - long word = pack(Short.MAX_VALUE, Short.MIN_VALUE, 0); - for (int i = 0; i < SIZE; i++) { - long entry = pointer + 32 * i; - UNSAFE.putLong(entry + 24, word); - } + UNSAFE.setMemory(pointer, SIZE, (byte) 0); } - public void add(Row row) { - long index = index(row.hash); - long header = ((long) row.hash << 32) | (row.length); + public void add(long reference, int length, int hash, int value) { + for (int offset = offset(hash);; offset = next(offset)) { + long address = pointer + offset; + long ref = UNSAFE.getLong(address); - while (true) { - long address = pointer + (index << 5); - long head = UNSAFE.getLong(address); - long ref = UNSAFE.getLong(address + 8); - boolean isHit = (head == 0) || (head == header && equal(ref, row.address, row.length)); - - if (isHit) { - long sum = UNSAFE.getLong(address + 16) + row.value; - long word = UNSAFE.getLong(address + 24); - int min = Math.min(min(word), row.value); - int max = Math.max(max(word), row.value); - int cnt = cnt(word) + 1; - - UNSAFE.putLong(address, header); - UNSAFE.putLong(address + 8, row.address); - UNSAFE.putLong(address + 16, sum); - UNSAFE.putLong(address + 24, pack(min, max, cnt)); + if (ref == 0) { + alloc(reference, length, hash, value, address); break; } - index = (index + 1) & (SIZE - 1); + if (equal(ref, reference, length)) { + long sum = UNSAFE.getLong(address + 16) + value; + int cnt = UNSAFE.getInt(address + 24) + 1; + short min = (short) Math.min(UNSAFE.getShort(address + 28), value); + short max = (short) Math.max(UNSAFE.getShort(address + 30), value); + + UNSAFE.putLong(address + 16, sum); + UNSAFE.putInt(address + 24, cnt); + UNSAFE.putShort(address + 28, min); + UNSAFE.putShort(address + 30, max); + break; + } } } public void merge(Aggregates rights) { - for (int rightIndex = 0; rightIndex < SIZE; rightIndex++) { - long rightAddress = rights.pointer + (rightIndex << 5); - long header = UNSAFE.getLong(rightAddress); - long reference = UNSAFE.getLong(rightAddress + 8); + for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 32) { + long rightAddress = rights.pointer + rightOffset; + long reference = UNSAFE.getLong(rightAddress); - if (header == 0) { + if (reference == 0) { continue; } - int hash = (int) (header >>> 32); - int length = (int) (header); - long index = index(hash); + int hash = UNSAFE.getInt(rightAddress + 8); + int length = UNSAFE.getInt(rightAddress + 12); - while (true) { - long address = pointer + (index << 5); - long head = UNSAFE.getLong(address); - long ref = UNSAFE.getLong(address + 8); - boolean isHit = (head == 0) || (head == header && equal(ref, reference, length)); + for (int offset = offset(hash);; offset = next(offset)) { + long address = pointer + offset; + long ref = UNSAFE.getLong(address); - if (isHit) { - long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16); - long left = UNSAFE.getLong(address + 24); - long right = UNSAFE.getLong(rightAddress + 24); - int min = Math.min(min(left), min(right)); - int max = Math.max(max(left), max(right)); - int cnt = cnt(left) + cnt(right); - - UNSAFE.putLong(address, header); - UNSAFE.putLong(address + 8, reference); - UNSAFE.putLong(address + 16, sum); - UNSAFE.putLong(address + 24, pack(min, max, cnt)); + if (ref == 0) { + UNSAFE.copyMemory(rightAddress, address, 32); break; } - index = (index + 1) & (SIZE - 1); + if (equal(ref, reference, length)) { + long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16); + int cnt = UNSAFE.getInt(address + 24) + UNSAFE.getInt(rightAddress + 24); + short min = (short) Math.min(UNSAFE.getShort(address + 28), UNSAFE.getShort(rightAddress + 28)); + short max = (short) Math.max(UNSAFE.getShort(address + 30), UNSAFE.getShort(rightAddress + 30)); + + UNSAFE.putLong(address + 16, sum); + UNSAFE.putInt(address + 24, cnt); + UNSAFE.putShort(address + 28, min); + UNSAFE.putShort(address + 30, max); + break; + } } } } @@ -243,68 +215,64 @@ public class CalculateAverage_artsiomkorzun { public Map aggregate() { TreeMap set = new TreeMap<>(); - for (int index = 0; index < SIZE; index++) { - long address = pointer + (index << 5); - long head = UNSAFE.getLong(address); - long ref = UNSAFE.getLong(address + 8); + for (int offset = 0; offset < SIZE; offset += 32) { + long address = pointer + offset; + long ref = UNSAFE.getLong(address); - if (head == 0) { - continue; + if (ref != 0) { + int length = UNSAFE.getInt(address + 12) - 1; + byte[] array = new byte[length]; + UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); + String key = new String(array); + + long sum = UNSAFE.getLong(address + 16); + int cnt = UNSAFE.getInt(address + 24); + short min = UNSAFE.getShort(address + 28); + short max = UNSAFE.getShort(address + 30); + + Aggregate aggregate = new Aggregate(min, max, sum, cnt); + set.put(key, aggregate); } - - int length = (int) (head); - byte[] array = new byte[length]; - UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); - String key = new String(array); - - long sum = UNSAFE.getLong(address + 16); - long word = UNSAFE.getLong(address + 24); - - Aggregate aggregate = new Aggregate(min(word), max(word), sum, cnt(word)); - set.put(key, aggregate); } return set; } - private static long pack(int min, int max, int cnt) { - return ((long) min << 48) | (((long) max & 0xFFFF) << 32) | cnt; + private static void alloc(long reference, int length, int hash, int value, long address) { + UNSAFE.putLong(address, reference); + UNSAFE.putInt(address + 8, hash); + UNSAFE.putInt(address + 12, length); + UNSAFE.putLong(address + 16, value); + UNSAFE.putInt(address + 24, 1); + UNSAFE.putShort(address + 28, (short) value); + UNSAFE.putShort(address + 30, (short) value); } - private static int cnt(long word) { - return (int) word; + private static int offset(int hash) { + return ((hash) & (ENTRIES - 1)) << 5; } - private static int max(long word) { - return (short) (word >>> 32); - } - - private static int min(long word) { - return (short) (word >>> 48); - } - - private static long index(int hash) { - return (hash ^ (hash >> 16)) & (SIZE - 1); + private static int next(int prev) { + return (prev + 32) & (SIZE - 1); } private static boolean equal(long leftAddress, long rightAddress, int length) { - int index = 0; - while (length > 8) { - long left = UNSAFE.getLong(leftAddress + index); - long right = UNSAFE.getLong(rightAddress + index); + long left = UNSAFE.getLong(leftAddress); + long right = UNSAFE.getLong(rightAddress); if (left != right) { return false; } + leftAddress += 8; + rightAddress += 8; length -= 8; - index += 8; } - int shift = 64 - (length << 3); - long left = getLongBigEndian(leftAddress + index) >>> shift; - long right = getLongBigEndian(rightAddress + index) >>> shift; + int shift = (8 - length) << 3; + long left = getLongLittleEndian(leftAddress) << shift; + long right = getLongLittleEndian(rightAddress) << shift; return (left == right); } } @@ -323,10 +291,18 @@ public class CalculateAverage_artsiomkorzun { @Override public void run() { Aggregates aggregates = new Aggregates(); - Row row = new Row(); for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) { - aggregate(aggregates, row, segment); + long position = (long) SEGMENT_SIZE * segment; + int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, MAPPED_FILE.byteSize() - position); + long address = MAPPED_FILE.address() + position; + long limit = address + Math.min(SEGMENT_SIZE, size - 1); + + if (segment > 0) { + address = next(address); + } + + aggregate(aggregates, address, limit); } while (!result.compareAndSet(null, aggregates)) { @@ -338,75 +314,62 @@ public class CalculateAverage_artsiomkorzun { } } - private static void aggregate(Aggregates aggregates, Row row, int segment) { - long position = (long) SEGMENT_SIZE * segment; - int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, MAPPED_FILE.byteSize() - position); - long address = MAPPED_FILE.address() + position; - long limit = address + Math.min(SEGMENT_SIZE, size - 1); + private static void aggregate(Aggregates aggregates, long position, long limit) { + // this parsing can produce seg fault at page boundaries + // e.g. file size is 4096 and the last entry is X=0.0, which is less than 8 bytes + // as a result a read will be split across pages, where one of them is not mapped + // but for some reason it works on my machine, leaving to investigate - if (segment > 0) { - address = next(address); - } + for (long start = position, hash = 0; position <= limit;) { + int length; // idea: royvanrijn, explanation: https://richardstartin.github.io/posts/finding-bytes + { + long word = getLongLittleEndian(position); + long match = word ^ COMMA_PATTERN; + long mask = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L; - while (address <= limit) { - // this parsing can produce seg fault at page boundaries - // e.g. file size is 4096 and the last entry is X=0.0, which is less than 8 bytes - // as a result a read will be split across pages, where one of them is not mapped - // but for some reason it works on my machine, leaving to investigate - address = parseKey(address, row); - address = parseValue(address, row); - aggregates.add(row); - } - } + if (mask == 0) { + hash ^= word; + position += 8; + continue; + } - private static long next(long address) { - while (UNSAFE.getByte(address++) != '\n') { - // continue - } - return address; - } - - // idea: royvanrijn - // explanation: https://richardstartin.github.io/posts/finding-bytes - private static long parseKey(long address, Row row) { - int length = 0; - long hash = 0; - long word; - - while (true) { - word = getLongLittleEndian(address + length); - long match = word ^ COMMA_PATTERN; - long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L; - - if (mask == 0) { - hash = 71 * hash + word; - length += 8; - continue; + int bit = Long.numberOfTrailingZeros(mask); + position += (bit >>> 3) + 1; // +sep + hash ^= (word << (69 - bit)); + length = (int) (position - start); } - int bit = Long.numberOfTrailingZeros(mask); - length += (bit >>> 3); - hash = 71 * hash + (word & (0x00FFFFFFFFFFFFFFL >>> (63 - bit))); + int value; // idea: merykitty + { + long word = getLongLittleEndian(position); + long inverted = ~word; + int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); + long signed = (inverted << 59) >> 63; + long mask = ~(signed & 0xFF); + long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; + long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; + value = (int) ((abs ^ signed) - signed); + position += (dot >> 3) + 3; + } - row.address = address; - row.length = length; - row.hash = Long.hashCode(hash); + aggregates.add(start, length, mix(hash), value); - return address + length + 1; + start = position; + hash = 0; } } - // idea: merykitty - private static long parseValue(long address, Row row) { - long word = getLongLittleEndian(address); - long inverted = ~word; - int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); - long signed = (inverted << 59) >> 63; - long mask = ~(signed & 0xFF); - long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; - long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; - row.value = (int) ((abs ^ signed) - signed); - return address + (dot >> 3) + 3; + private static long next(long position) { + while (UNSAFE.getByte(position++) != '\n') { + // continue + } + return position; + } + + private static int mix(long x) { + long h = x * -7046029254386353131L; + h ^= h >>> 32; + return (int) (h ^ h >>> 16); } } }