improved artsiomkorzun solution (#321)
This commit is contained in:
		| @@ -15,5 +15,5 @@ | ||||
| #  limitations under the License. | ||||
| # | ||||
|  | ||||
| JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC" | ||||
| JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation" | ||||
| java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_artsiomkorzun | ||||
|   | ||||
| @@ -35,7 +35,7 @@ public class CalculateAverage_artsiomkorzun { | ||||
|     private static final MemorySegment MAPPED_FILE = map(FILE); | ||||
|  | ||||
|     private static final int PARALLELISM = Runtime.getRuntime().availableProcessors(); | ||||
|     private static final int SEGMENT_SIZE = 16 * 1024 * 1024; | ||||
|     private static final int SEGMENT_SIZE = 32 * 1024 * 1024; | ||||
|     private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE); | ||||
|     private static final int SEGMENT_OVERLAP = 1024; | ||||
|     private static final long COMMA_PATTERN = pattern(';'); | ||||
| @@ -100,16 +100,6 @@ public class CalculateAverage_artsiomkorzun { | ||||
|         return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56); | ||||
|     } | ||||
|  | ||||
|     private static long getLongBigEndian(long address) { | ||||
|         long value = UNSAFE.getLong(address); | ||||
|  | ||||
|         if (BYTE_ORDER == ByteOrder.LITTLE_ENDIAN) { | ||||
|             value = Long.reverseBytes(value); | ||||
|         } | ||||
|  | ||||
|         return value; | ||||
|     } | ||||
|  | ||||
|     private static long getLongLittleEndian(long address) { | ||||
|         long value = UNSAFE.getLong(address); | ||||
|  | ||||
| @@ -144,98 +134,80 @@ public class CalculateAverage_artsiomkorzun { | ||||
|         return Math.round(v) / 10.0; | ||||
|     } | ||||
|  | ||||
|     private static class Row { | ||||
|         long address; | ||||
|         int length; | ||||
|         int hash; | ||||
|         int value; | ||||
|     } | ||||
|  | ||||
|     private record Aggregate(int min, int max, long sum, int cnt) { | ||||
|     } | ||||
|  | ||||
|     private static class Aggregates { | ||||
|  | ||||
|         private static final int SIZE = 16 * 1024; | ||||
|         private static final int ENTRIES = 64 * 1024; | ||||
|         private static final int SIZE = 32 * ENTRIES; | ||||
|  | ||||
|         private final long pointer; | ||||
|  | ||||
|         public Aggregates() { | ||||
|             int size = 32 * SIZE; | ||||
|             long address = UNSAFE.allocateMemory(size + 8096); | ||||
|             long address = UNSAFE.allocateMemory(SIZE + 8096); | ||||
|             pointer = (address + 4095) & (~4095); | ||||
|             UNSAFE.setMemory(pointer, size, (byte) 0); | ||||
|  | ||||
|             long word = pack(Short.MAX_VALUE, Short.MIN_VALUE, 0); | ||||
|             for (int i = 0; i < SIZE; i++) { | ||||
|                 long entry = pointer + 32 * i; | ||||
|                 UNSAFE.putLong(entry + 24, word); | ||||
|             } | ||||
|             UNSAFE.setMemory(pointer, SIZE, (byte) 0); | ||||
|         } | ||||
|  | ||||
|         public void add(Row row) { | ||||
|             long index = index(row.hash); | ||||
|             long header = ((long) row.hash << 32) | (row.length); | ||||
|         public void add(long reference, int length, int hash, int value) { | ||||
|             for (int offset = offset(hash);; offset = next(offset)) { | ||||
|                 long address = pointer + offset; | ||||
|                 long ref = UNSAFE.getLong(address); | ||||
|  | ||||
|             while (true) { | ||||
|                 long address = pointer + (index << 5); | ||||
|                 long head = UNSAFE.getLong(address); | ||||
|                 long ref = UNSAFE.getLong(address + 8); | ||||
|                 boolean isHit = (head == 0) || (head == header && equal(ref, row.address, row.length)); | ||||
|  | ||||
|                 if (isHit) { | ||||
|                     long sum = UNSAFE.getLong(address + 16) + row.value; | ||||
|                     long word = UNSAFE.getLong(address + 24); | ||||
|                     int min = Math.min(min(word), row.value); | ||||
|                     int max = Math.max(max(word), row.value); | ||||
|                     int cnt = cnt(word) + 1; | ||||
|  | ||||
|                     UNSAFE.putLong(address, header); | ||||
|                     UNSAFE.putLong(address + 8, row.address); | ||||
|                     UNSAFE.putLong(address + 16, sum); | ||||
|                     UNSAFE.putLong(address + 24, pack(min, max, cnt)); | ||||
|                 if (ref == 0) { | ||||
|                     alloc(reference, length, hash, value, address); | ||||
|                     break; | ||||
|                 } | ||||
|  | ||||
|                 index = (index + 1) & (SIZE - 1); | ||||
|                 if (equal(ref, reference, length)) { | ||||
|                     long sum = UNSAFE.getLong(address + 16) + value; | ||||
|                     int cnt = UNSAFE.getInt(address + 24) + 1; | ||||
|                     short min = (short) Math.min(UNSAFE.getShort(address + 28), value); | ||||
|                     short max = (short) Math.max(UNSAFE.getShort(address + 30), value); | ||||
|  | ||||
|                     UNSAFE.putLong(address + 16, sum); | ||||
|                     UNSAFE.putInt(address + 24, cnt); | ||||
|                     UNSAFE.putShort(address + 28, min); | ||||
|                     UNSAFE.putShort(address + 30, max); | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public void merge(Aggregates rights) { | ||||
|             for (int rightIndex = 0; rightIndex < SIZE; rightIndex++) { | ||||
|                 long rightAddress = rights.pointer + (rightIndex << 5); | ||||
|                 long header = UNSAFE.getLong(rightAddress); | ||||
|                 long reference = UNSAFE.getLong(rightAddress + 8); | ||||
|             for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 32) { | ||||
|                 long rightAddress = rights.pointer + rightOffset; | ||||
|                 long reference = UNSAFE.getLong(rightAddress); | ||||
|  | ||||
|                 if (header == 0) { | ||||
|                 if (reference == 0) { | ||||
|                     continue; | ||||
|                 } | ||||
|  | ||||
|                 int hash = (int) (header >>> 32); | ||||
|                 int length = (int) (header); | ||||
|                 long index = index(hash); | ||||
|                 int hash = UNSAFE.getInt(rightAddress + 8); | ||||
|                 int length = UNSAFE.getInt(rightAddress + 12); | ||||
|  | ||||
|                 while (true) { | ||||
|                     long address = pointer + (index << 5); | ||||
|                     long head = UNSAFE.getLong(address); | ||||
|                     long ref = UNSAFE.getLong(address + 8); | ||||
|                     boolean isHit = (head == 0) || (head == header && equal(ref, reference, length)); | ||||
|                 for (int offset = offset(hash);; offset = next(offset)) { | ||||
|                     long address = pointer + offset; | ||||
|                     long ref = UNSAFE.getLong(address); | ||||
|  | ||||
|                     if (isHit) { | ||||
|                         long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16); | ||||
|                         long left = UNSAFE.getLong(address + 24); | ||||
|                         long right = UNSAFE.getLong(rightAddress + 24); | ||||
|                         int min = Math.min(min(left), min(right)); | ||||
|                         int max = Math.max(max(left), max(right)); | ||||
|                         int cnt = cnt(left) + cnt(right); | ||||
|  | ||||
|                         UNSAFE.putLong(address, header); | ||||
|                         UNSAFE.putLong(address + 8, reference); | ||||
|                         UNSAFE.putLong(address + 16, sum); | ||||
|                         UNSAFE.putLong(address + 24, pack(min, max, cnt)); | ||||
|                     if (ref == 0) { | ||||
|                         UNSAFE.copyMemory(rightAddress, address, 32); | ||||
|                         break; | ||||
|                     } | ||||
|  | ||||
|                     index = (index + 1) & (SIZE - 1); | ||||
|                     if (equal(ref, reference, length)) { | ||||
|                         long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16); | ||||
|                         int cnt = UNSAFE.getInt(address + 24) + UNSAFE.getInt(rightAddress + 24); | ||||
|                         short min = (short) Math.min(UNSAFE.getShort(address + 28), UNSAFE.getShort(rightAddress + 28)); | ||||
|                         short max = (short) Math.max(UNSAFE.getShort(address + 30), UNSAFE.getShort(rightAddress + 30)); | ||||
|  | ||||
|                         UNSAFE.putLong(address + 16, sum); | ||||
|                         UNSAFE.putInt(address + 24, cnt); | ||||
|                         UNSAFE.putShort(address + 28, min); | ||||
|                         UNSAFE.putShort(address + 30, max); | ||||
|                         break; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| @@ -243,68 +215,64 @@ public class CalculateAverage_artsiomkorzun { | ||||
|         public Map<String, Aggregate> aggregate() { | ||||
|             TreeMap<String, Aggregate> set = new TreeMap<>(); | ||||
|  | ||||
|             for (int index = 0; index < SIZE; index++) { | ||||
|                 long address = pointer + (index << 5); | ||||
|                 long head = UNSAFE.getLong(address); | ||||
|                 long ref = UNSAFE.getLong(address + 8); | ||||
|             for (int offset = 0; offset < SIZE; offset += 32) { | ||||
|                 long address = pointer + offset; | ||||
|                 long ref = UNSAFE.getLong(address); | ||||
|  | ||||
|                 if (head == 0) { | ||||
|                     continue; | ||||
|                 if (ref != 0) { | ||||
|                     int length = UNSAFE.getInt(address + 12) - 1; | ||||
|                     byte[] array = new byte[length]; | ||||
|                     UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); | ||||
|                     String key = new String(array); | ||||
|  | ||||
|                     long sum = UNSAFE.getLong(address + 16); | ||||
|                     int cnt = UNSAFE.getInt(address + 24); | ||||
|                     short min = UNSAFE.getShort(address + 28); | ||||
|                     short max = UNSAFE.getShort(address + 30); | ||||
|  | ||||
|                     Aggregate aggregate = new Aggregate(min, max, sum, cnt); | ||||
|                     set.put(key, aggregate); | ||||
|                 } | ||||
|  | ||||
|                 int length = (int) (head); | ||||
|                 byte[] array = new byte[length]; | ||||
|                 UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); | ||||
|                 String key = new String(array); | ||||
|  | ||||
|                 long sum = UNSAFE.getLong(address + 16); | ||||
|                 long word = UNSAFE.getLong(address + 24); | ||||
|  | ||||
|                 Aggregate aggregate = new Aggregate(min(word), max(word), sum, cnt(word)); | ||||
|                 set.put(key, aggregate); | ||||
|             } | ||||
|  | ||||
|             return set; | ||||
|         } | ||||
|  | ||||
|         private static long pack(int min, int max, int cnt) { | ||||
|             return ((long) min << 48) | (((long) max & 0xFFFF) << 32) | cnt; | ||||
|         private static void alloc(long reference, int length, int hash, int value, long address) { | ||||
|             UNSAFE.putLong(address, reference); | ||||
|             UNSAFE.putInt(address + 8, hash); | ||||
|             UNSAFE.putInt(address + 12, length); | ||||
|             UNSAFE.putLong(address + 16, value); | ||||
|             UNSAFE.putInt(address + 24, 1); | ||||
|             UNSAFE.putShort(address + 28, (short) value); | ||||
|             UNSAFE.putShort(address + 30, (short) value); | ||||
|         } | ||||
|  | ||||
|         private static int cnt(long word) { | ||||
|             return (int) word; | ||||
|         private static int offset(int hash) { | ||||
|             return ((hash) & (ENTRIES - 1)) << 5; | ||||
|         } | ||||
|  | ||||
|         private static int max(long word) { | ||||
|             return (short) (word >>> 32); | ||||
|         } | ||||
|  | ||||
|         private static int min(long word) { | ||||
|             return (short) (word >>> 48); | ||||
|         } | ||||
|  | ||||
|         private static long index(int hash) { | ||||
|             return (hash ^ (hash >> 16)) & (SIZE - 1); | ||||
|         private static int next(int prev) { | ||||
|             return (prev + 32) & (SIZE - 1); | ||||
|         } | ||||
|  | ||||
|         private static boolean equal(long leftAddress, long rightAddress, int length) { | ||||
|             int index = 0; | ||||
|  | ||||
|             while (length > 8) { | ||||
|                 long left = UNSAFE.getLong(leftAddress + index); | ||||
|                 long right = UNSAFE.getLong(rightAddress + index); | ||||
|                 long left = UNSAFE.getLong(leftAddress); | ||||
|                 long right = UNSAFE.getLong(rightAddress); | ||||
|  | ||||
|                 if (left != right) { | ||||
|                     return false; | ||||
|                 } | ||||
|  | ||||
|                 leftAddress += 8; | ||||
|                 rightAddress += 8; | ||||
|                 length -= 8; | ||||
|                 index += 8; | ||||
|             } | ||||
|  | ||||
|             int shift = 64 - (length << 3); | ||||
|             long left = getLongBigEndian(leftAddress + index) >>> shift; | ||||
|             long right = getLongBigEndian(rightAddress + index) >>> shift; | ||||
|             int shift = (8 - length) << 3; | ||||
|             long left = getLongLittleEndian(leftAddress) << shift; | ||||
|             long right = getLongLittleEndian(rightAddress) << shift; | ||||
|             return (left == right); | ||||
|         } | ||||
|     } | ||||
| @@ -323,10 +291,18 @@ public class CalculateAverage_artsiomkorzun { | ||||
|         @Override | ||||
|         public void run() { | ||||
|             Aggregates aggregates = new Aggregates(); | ||||
|             Row row = new Row(); | ||||
|  | ||||
|             for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) { | ||||
|                 aggregate(aggregates, row, segment); | ||||
|                 long position = (long) SEGMENT_SIZE * segment; | ||||
|                 int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, MAPPED_FILE.byteSize() - position); | ||||
|                 long address = MAPPED_FILE.address() + position; | ||||
|                 long limit = address + Math.min(SEGMENT_SIZE, size - 1); | ||||
|  | ||||
|                 if (segment > 0) { | ||||
|                     address = next(address); | ||||
|                 } | ||||
|  | ||||
|                 aggregate(aggregates, address, limit); | ||||
|             } | ||||
|  | ||||
|             while (!result.compareAndSet(null, aggregates)) { | ||||
| @@ -338,75 +314,62 @@ public class CalculateAverage_artsiomkorzun { | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         private static void aggregate(Aggregates aggregates, Row row, int segment) { | ||||
|             long position = (long) SEGMENT_SIZE * segment; | ||||
|             int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, MAPPED_FILE.byteSize() - position); | ||||
|             long address = MAPPED_FILE.address() + position; | ||||
|             long limit = address + Math.min(SEGMENT_SIZE, size - 1); | ||||
|         private static void aggregate(Aggregates aggregates, long position, long limit) { | ||||
|             // this parsing can produce seg fault at page boundaries | ||||
|             // e.g. file size is 4096 and the last entry is X=0.0, which is less than 8 bytes | ||||
|             // as a result a read will be split across pages, where one of them is not mapped | ||||
|             // but for some reason it works on my machine, leaving to investigate | ||||
|  | ||||
|             if (segment > 0) { | ||||
|                 address = next(address); | ||||
|             } | ||||
|             for (long start = position, hash = 0; position <= limit;) { | ||||
|                 int length; // idea: royvanrijn, explanation: https://richardstartin.github.io/posts/finding-bytes | ||||
|                 { | ||||
|                     long word = getLongLittleEndian(position); | ||||
|                     long match = word ^ COMMA_PATTERN; | ||||
|                     long mask = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L; | ||||
|  | ||||
|             while (address <= limit) { | ||||
|                 // this parsing can produce seg fault at page boundaries | ||||
|                 // e.g. file size is 4096 and the last entry is X=0.0, which is less than 8 bytes | ||||
|                 // as a result a read will be split across pages, where one of them is not mapped | ||||
|                 // but for some reason it works on my machine, leaving to investigate | ||||
|                 address = parseKey(address, row); | ||||
|                 address = parseValue(address, row); | ||||
|                 aggregates.add(row); | ||||
|             } | ||||
|         } | ||||
|                     if (mask == 0) { | ||||
|                         hash ^= word; | ||||
|                         position += 8; | ||||
|                         continue; | ||||
|                     } | ||||
|  | ||||
|         private static long next(long address) { | ||||
|             while (UNSAFE.getByte(address++) != '\n') { | ||||
|                 // continue | ||||
|             } | ||||
|             return address; | ||||
|         } | ||||
|  | ||||
|         // idea: royvanrijn | ||||
|         // explanation: https://richardstartin.github.io/posts/finding-bytes | ||||
|         private static long parseKey(long address, Row row) { | ||||
|             int length = 0; | ||||
|             long hash = 0; | ||||
|             long word; | ||||
|  | ||||
|             while (true) { | ||||
|                 word = getLongLittleEndian(address + length); | ||||
|                 long match = word ^ COMMA_PATTERN; | ||||
|                 long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L; | ||||
|  | ||||
|                 if (mask == 0) { | ||||
|                     hash = 71 * hash + word; | ||||
|                     length += 8; | ||||
|                     continue; | ||||
|                     int bit = Long.numberOfTrailingZeros(mask); | ||||
|                     position += (bit >>> 3) + 1; // +sep | ||||
|                     hash ^= (word << (69 - bit)); | ||||
|                     length = (int) (position - start); | ||||
|                 } | ||||
|  | ||||
|                 int bit = Long.numberOfTrailingZeros(mask); | ||||
|                 length += (bit >>> 3); | ||||
|                 hash = 71 * hash + (word & (0x00FFFFFFFFFFFFFFL >>> (63 - bit))); | ||||
|                 int value; // idea: merykitty | ||||
|                 { | ||||
|                     long word = getLongLittleEndian(position); | ||||
|                     long inverted = ~word; | ||||
|                     int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); | ||||
|                     long signed = (inverted << 59) >> 63; | ||||
|                     long mask = ~(signed & 0xFF); | ||||
|                     long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; | ||||
|                     long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; | ||||
|                     value = (int) ((abs ^ signed) - signed); | ||||
|                     position += (dot >> 3) + 3; | ||||
|                 } | ||||
|  | ||||
|                 row.address = address; | ||||
|                 row.length = length; | ||||
|                 row.hash = Long.hashCode(hash); | ||||
|                 aggregates.add(start, length, mix(hash), value); | ||||
|  | ||||
|                 return address + length + 1; | ||||
|                 start = position; | ||||
|                 hash = 0; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         // idea: merykitty | ||||
|         private static long parseValue(long address, Row row) { | ||||
|             long word = getLongLittleEndian(address); | ||||
|             long inverted = ~word; | ||||
|             int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); | ||||
|             long signed = (inverted << 59) >> 63; | ||||
|             long mask = ~(signed & 0xFF); | ||||
|             long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; | ||||
|             long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; | ||||
|             row.value = (int) ((abs ^ signed) - signed); | ||||
|             return address + (dot >> 3) + 3; | ||||
|         private static long next(long position) { | ||||
|             while (UNSAFE.getByte(position++) != '\n') { | ||||
|                 // continue | ||||
|             } | ||||
|             return position; | ||||
|         } | ||||
|  | ||||
|         private static int mix(long x) { | ||||
|             long h = x * -7046029254386353131L; | ||||
|             h ^= h >>> 32; | ||||
|             return (int) (h ^ h >>> 16); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user