diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java index 4f6c8fd..f92f414 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java @@ -20,7 +20,6 @@ import sun.misc.Unsafe; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.reflect.Field; -import java.nio.ByteOrder; import java.nio.channels.FileChannel; import java.nio.file.Path; import java.nio.file.StandardOpenOption; @@ -38,11 +37,10 @@ public class CalculateAverage_artsiomkorzun { private static final int SEGMENT_SIZE = 32 * 1024 * 1024; private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE); private static final int SEGMENT_OVERLAP = 1024; - private static final long COMMA_PATTERN = pattern(';'); + private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL; private static final long DOT_BITS = 0x10101000; private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); - private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder(); private static final Unsafe UNSAFE; static { @@ -95,19 +93,15 @@ public class CalculateAverage_artsiomkorzun { } } - private static long pattern(char c) { - long b = c & 0xFFL; - return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56); - } - - private static long getLongLittleEndian(long address) { - long value = UNSAFE.getLong(address); - - if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) { - value = Long.reverseBytes(value); - } - - return value; + private static long word(long address) { + return UNSAFE.getLong(address); + /* + * if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) { + * value = Long.reverseBytes(value); + * } + * + * return value; + */ } private static String text(Map aggregates) { @@ -140,7 +134,7 @@ public class CalculateAverage_artsiomkorzun { private static class Aggregates { private static final int ENTRIES = 64 * 1024; - private static final int SIZE = 32 * ENTRIES; + private static final int SIZE = 128 * ENTRIES; private final long pointer; @@ -150,62 +144,82 @@ public class CalculateAverage_artsiomkorzun { UNSAFE.setMemory(pointer, SIZE, (byte) 0); } - public void add(long reference, int length, int hash, int value) { + public long find(long word, int hash) { + long address = pointer + offset(hash); + long w = word(address + 24); + return (w == word) ? address : 0; + } + + public long find(long word1, long word2, int hash) { + long address = pointer + offset(hash); + long w1 = word(address + 24); + long w2 = word(address + 32); + return (word1 == w1) && (word2 == w2) ? address : 0; + } + + public long put(long reference, long word, int length, int hash) { for (int offset = offset(hash);; offset = next(offset)) { long address = pointer + offset; - long ref = UNSAFE.getLong(address); - - if (ref == 0) { - alloc(reference, length, hash, value, address); - break; + if (equal(reference, word, address + 24, length)) { + return address; } - if (equal(ref, reference, length)) { - long sum = UNSAFE.getLong(address + 16) + value; - int cnt = UNSAFE.getInt(address + 24) + 1; - short min = (short) Math.min(UNSAFE.getShort(address + 28), value); - short max = (short) Math.max(UNSAFE.getShort(address + 30), value); - - UNSAFE.putLong(address + 16, sum); - UNSAFE.putInt(address + 24, cnt); - UNSAFE.putShort(address + 28, min); - UNSAFE.putShort(address + 30, max); - break; + int len = UNSAFE.getInt(address); + if (len == 0) { + alloc(reference, length, hash, address); + return address; } } } - public void merge(Aggregates rights) { - for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 32) { - long rightAddress = rights.pointer + rightOffset; - long reference = UNSAFE.getLong(rightAddress); + public static void update(long address, int value) { + long sum = UNSAFE.getLong(address + 8) + value; + int cnt = UNSAFE.getInt(address + 16) + 1; + short min = UNSAFE.getShort(address + 20); + short max = UNSAFE.getShort(address + 22); - if (reference == 0) { + UNSAFE.putLong(address + 8, sum); + UNSAFE.putInt(address + 16, cnt); + + if (value < min) { + UNSAFE.putShort(address + 20, (short) value); + } + + if (value > max) { + UNSAFE.putShort(address + 22, (short) value); + } + } + + public void merge(Aggregates rights) { + for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 128) { + long rightAddress = rights.pointer + rightOffset; + int length = UNSAFE.getInt(rightAddress); + + if (length == 0) { continue; } - int hash = UNSAFE.getInt(rightAddress + 8); - int length = UNSAFE.getInt(rightAddress + 12); + int hash = UNSAFE.getInt(rightAddress + 4); for (int offset = offset(hash);; offset = next(offset)) { long address = pointer + offset; - long ref = UNSAFE.getLong(address); + int len = UNSAFE.getInt(address); - if (ref == 0) { - UNSAFE.copyMemory(rightAddress, address, 32); + if (len == 0) { + UNSAFE.copyMemory(rightAddress, address, 24 + length); break; } - if (equal(ref, reference, length)) { - long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16); - int cnt = UNSAFE.getInt(address + 24) + UNSAFE.getInt(rightAddress + 24); - short min = (short) Math.min(UNSAFE.getShort(address + 28), UNSAFE.getShort(rightAddress + 28)); - short max = (short) Math.max(UNSAFE.getShort(address + 30), UNSAFE.getShort(rightAddress + 30)); + if (len == length && equal(address + 24, rightAddress + 24, length)) { + long sum = UNSAFE.getLong(address + 8) + UNSAFE.getLong(rightAddress + 8); + int cnt = UNSAFE.getInt(address + 16) + UNSAFE.getInt(rightAddress + 16); + short min = (short) Math.min(UNSAFE.getShort(address + 20), UNSAFE.getShort(rightAddress + 20)); + short max = (short) Math.max(UNSAFE.getShort(address + 22), UNSAFE.getShort(rightAddress + 22)); - UNSAFE.putLong(address + 16, sum); - UNSAFE.putInt(address + 24, cnt); - UNSAFE.putShort(address + 28, min); - UNSAFE.putShort(address + 30, max); + UNSAFE.putLong(address + 8, sum); + UNSAFE.putInt(address + 16, cnt); + UNSAFE.putShort(address + 20, min); + UNSAFE.putShort(address + 22, max); break; } } @@ -215,20 +229,19 @@ public class CalculateAverage_artsiomkorzun { public Map aggregate() { TreeMap set = new TreeMap<>(); - for (int offset = 0; offset < SIZE; offset += 32) { + for (int offset = 0; offset < SIZE; offset += 128) { long address = pointer + offset; - long ref = UNSAFE.getLong(address); + int length = UNSAFE.getInt(address); - if (ref != 0) { - int length = UNSAFE.getInt(address + 12) - 1; + if (length != 0) { byte[] array = new byte[length]; - UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); + UNSAFE.copyMemory(null, address + 24, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); String key = new String(array); - long sum = UNSAFE.getLong(address + 16); - int cnt = UNSAFE.getInt(address + 24); - short min = UNSAFE.getShort(address + 28); - short max = UNSAFE.getShort(address + 30); + long sum = UNSAFE.getLong(address + 8); + int cnt = UNSAFE.getInt(address + 16); + short min = UNSAFE.getShort(address + 20); + short max = UNSAFE.getShort(address + 22); Aggregate aggregate = new Aggregate(min, max, sum, cnt); set.put(key, aggregate); @@ -238,26 +251,24 @@ public class CalculateAverage_artsiomkorzun { return set; } - private static void alloc(long reference, int length, int hash, int value, long address) { - UNSAFE.putLong(address, reference); - UNSAFE.putInt(address + 8, hash); - UNSAFE.putInt(address + 12, length); - UNSAFE.putLong(address + 16, value); - UNSAFE.putInt(address + 24, 1); - UNSAFE.putShort(address + 28, (short) value); - UNSAFE.putShort(address + 30, (short) value); + private static void alloc(long reference, int length, int hash, long address) { + UNSAFE.putInt(address, length); + UNSAFE.putInt(address + 4, hash); + UNSAFE.putShort(address + 20, Short.MAX_VALUE); + UNSAFE.putShort(address + 22, Short.MIN_VALUE); + UNSAFE.copyMemory(reference, address + 24, length); } private static int offset(int hash) { - return ((hash) & (ENTRIES - 1)) << 5; + return ((hash) & (ENTRIES - 1)) << 7; } private static int next(int prev) { - return (prev + 32) & (SIZE - 1); + return (prev + 128) & (SIZE - 1); } - private static boolean equal(long leftAddress, long rightAddress, int length) { - while (length > 8) { + private static boolean equal(long leftAddress, long leftWord, long rightAddress, int length) { + while (length >= 8) { long left = UNSAFE.getLong(leftAddress); long right = UNSAFE.getLong(rightAddress); @@ -270,10 +281,24 @@ public class CalculateAverage_artsiomkorzun { length -= 8; } - int shift = (8 - length) << 3; - long left = getLongLittleEndian(leftAddress) << shift; - long right = getLongLittleEndian(rightAddress) << shift; - return (left == right); + return leftWord == word(rightAddress); + } + + private static boolean equal(long leftAddress, long rightAddress, int length) { + do { + long left = UNSAFE.getLong(leftAddress); + long right = UNSAFE.getLong(rightAddress); + + if (left != right) { + return false; + } + + leftAddress += 8; + rightAddress += 8; + length -= 8; + } while (length > 0); + + return true; } } @@ -320,45 +345,89 @@ public class CalculateAverage_artsiomkorzun { // as a result a read will be split across pages, where one of them is not mapped // but for some reason it works on my machine, leaving to investigate - for (long start = position, hash = 0; position <= limit;) { - int length; // idea: royvanrijn, explanation: https://richardstartin.github.io/posts/finding-bytes - { - long word = getLongLittleEndian(position); - long match = word ^ COMMA_PATTERN; - long mask = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L; + while (position <= limit) { // branchy version, credit: thomaswue + int length; + int hash; - if (mask == 0) { - hash ^= word; - position += 8; - continue; + long ptr = 0; + long word = word(position); + long separator = separator(word); + + if (separator != 0) { + length = length(separator); + word = mask(word, separator); + hash = mix(word); + ptr = aggregates.find(word, hash); + } + else { + long word0 = word; + word = word(position + 8); + separator = separator(word); + + if (separator != 0) { + length = length(separator) + 8; + word = mask(word, separator); + hash = mix(word ^ word0); + ptr = aggregates.find(word0, word, hash); } + else { + length = 16; + long h = word ^ word0; - int bit = Long.numberOfTrailingZeros(mask); - position += (bit >>> 3) + 1; // +sep - hash ^= (word << (69 - bit)); - length = (int) (position - start); + while (true) { + word = word(position + length); + separator = separator(word); + + if (separator == 0) { + length += 8; + h ^= word; + continue; + } + + length += length(separator); + word = mask(word, separator); + hash = mix(h ^ word); + break; + } + } } - int value; // idea: merykitty - { - long word = getLongLittleEndian(position); - long inverted = ~word; - int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); - long signed = (inverted << 59) >> 63; - long mask = ~(signed & 0xFF); - long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; - long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; - value = (int) ((abs ^ signed) - signed); - position += (dot >> 3) + 3; + if (ptr == 0) { + ptr = aggregates.put(position, word, length, hash); } - aggregates.add(start, length, mix(hash), value); - - start = position; - hash = 0; + position = update(ptr, position + length + 1); } } + private static long update(long ptr, long position) { + // idea: merykitty + long word = word(position); + long inverted = ~word; + int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); + long signed = (inverted << 59) >> 63; + long mask = ~(signed & 0xFF); + long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; + long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; + int value = (int) ((abs ^ signed) - signed); + + Aggregates.update(ptr, value); + return position + (dot >> 3) + 3; + } + + private static long separator(long word) { + long match = word ^ COMMA_PATTERN; + return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L); + } + + private static long mask(long word, long separator) { + return word & ((separator >>> 7) - 1) & 0x00FFFFFFFFFFFFFFL; + } + + private static int length(long separator) { + return Long.numberOfTrailingZeros(separator) >>> 3; + } + private static long next(long position) { while (UNSAFE.getByte(position++) != '\n') { // continue