branchy version (#408)
This commit is contained in:
		| @@ -20,7 +20,6 @@ import sun.misc.Unsafe; | ||||
| import java.lang.foreign.Arena; | ||||
| import java.lang.foreign.MemorySegment; | ||||
| import java.lang.reflect.Field; | ||||
| import java.nio.ByteOrder; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.file.Path; | ||||
| import java.nio.file.StandardOpenOption; | ||||
| @@ -38,11 +37,10 @@ public class CalculateAverage_artsiomkorzun { | ||||
|     private static final int SEGMENT_SIZE = 32 * 1024 * 1024; | ||||
|     private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE); | ||||
|     private static final int SEGMENT_OVERLAP = 1024; | ||||
|     private static final long COMMA_PATTERN = pattern(';'); | ||||
|     private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL; | ||||
|     private static final long DOT_BITS = 0x10101000; | ||||
|     private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); | ||||
|  | ||||
|     private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder(); | ||||
|     private static final Unsafe UNSAFE; | ||||
|  | ||||
|     static { | ||||
| @@ -95,19 +93,15 @@ public class CalculateAverage_artsiomkorzun { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static long pattern(char c) { | ||||
|         long b = c & 0xFFL; | ||||
|         return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56); | ||||
|     } | ||||
|  | ||||
|     private static long getLongLittleEndian(long address) { | ||||
|         long value = UNSAFE.getLong(address); | ||||
|  | ||||
|         if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) { | ||||
|             value = Long.reverseBytes(value); | ||||
|         } | ||||
|  | ||||
|         return value; | ||||
|     private static long word(long address) { | ||||
|         return UNSAFE.getLong(address); | ||||
|         /* | ||||
|          * if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) { | ||||
|          * value = Long.reverseBytes(value); | ||||
|          * } | ||||
|          * | ||||
|          * return value; | ||||
|          */ | ||||
|     } | ||||
|  | ||||
|     private static String text(Map<String, Aggregate> aggregates) { | ||||
| @@ -140,7 +134,7 @@ public class CalculateAverage_artsiomkorzun { | ||||
|     private static class Aggregates { | ||||
|  | ||||
|         private static final int ENTRIES = 64 * 1024; | ||||
|         private static final int SIZE = 32 * ENTRIES; | ||||
|         private static final int SIZE = 128 * ENTRIES; | ||||
|  | ||||
|         private final long pointer; | ||||
|  | ||||
| @@ -150,62 +144,82 @@ public class CalculateAverage_artsiomkorzun { | ||||
|             UNSAFE.setMemory(pointer, SIZE, (byte) 0); | ||||
|         } | ||||
|  | ||||
|         public void add(long reference, int length, int hash, int value) { | ||||
|         public long find(long word, int hash) { | ||||
|             long address = pointer + offset(hash); | ||||
|             long w = word(address + 24); | ||||
|             return (w == word) ? address : 0; | ||||
|         } | ||||
|  | ||||
|         public long find(long word1, long word2, int hash) { | ||||
|             long address = pointer + offset(hash); | ||||
|             long w1 = word(address + 24); | ||||
|             long w2 = word(address + 32); | ||||
|             return (word1 == w1) && (word2 == w2) ? address : 0; | ||||
|         } | ||||
|  | ||||
|         public long put(long reference, long word, int length, int hash) { | ||||
|             for (int offset = offset(hash);; offset = next(offset)) { | ||||
|                 long address = pointer + offset; | ||||
|                 long ref = UNSAFE.getLong(address); | ||||
|  | ||||
|                 if (ref == 0) { | ||||
|                     alloc(reference, length, hash, value, address); | ||||
|                     break; | ||||
|                 if (equal(reference, word, address + 24, length)) { | ||||
|                     return address; | ||||
|                 } | ||||
|  | ||||
|                 if (equal(ref, reference, length)) { | ||||
|                     long sum = UNSAFE.getLong(address + 16) + value; | ||||
|                     int cnt = UNSAFE.getInt(address + 24) + 1; | ||||
|                     short min = (short) Math.min(UNSAFE.getShort(address + 28), value); | ||||
|                     short max = (short) Math.max(UNSAFE.getShort(address + 30), value); | ||||
|  | ||||
|                     UNSAFE.putLong(address + 16, sum); | ||||
|                     UNSAFE.putInt(address + 24, cnt); | ||||
|                     UNSAFE.putShort(address + 28, min); | ||||
|                     UNSAFE.putShort(address + 30, max); | ||||
|                     break; | ||||
|                 int len = UNSAFE.getInt(address); | ||||
|                 if (len == 0) { | ||||
|                     alloc(reference, length, hash, address); | ||||
|                     return address; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public void merge(Aggregates rights) { | ||||
|             for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 32) { | ||||
|                 long rightAddress = rights.pointer + rightOffset; | ||||
|                 long reference = UNSAFE.getLong(rightAddress); | ||||
|         public static void update(long address, int value) { | ||||
|             long sum = UNSAFE.getLong(address + 8) + value; | ||||
|             int cnt = UNSAFE.getInt(address + 16) + 1; | ||||
|             short min = UNSAFE.getShort(address + 20); | ||||
|             short max = UNSAFE.getShort(address + 22); | ||||
|  | ||||
|                 if (reference == 0) { | ||||
|             UNSAFE.putLong(address + 8, sum); | ||||
|             UNSAFE.putInt(address + 16, cnt); | ||||
|  | ||||
|             if (value < min) { | ||||
|                 UNSAFE.putShort(address + 20, (short) value); | ||||
|             } | ||||
|  | ||||
|             if (value > max) { | ||||
|                 UNSAFE.putShort(address + 22, (short) value); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public void merge(Aggregates rights) { | ||||
|             for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 128) { | ||||
|                 long rightAddress = rights.pointer + rightOffset; | ||||
|                 int length = UNSAFE.getInt(rightAddress); | ||||
|  | ||||
|                 if (length == 0) { | ||||
|                     continue; | ||||
|                 } | ||||
|  | ||||
|                 int hash = UNSAFE.getInt(rightAddress + 8); | ||||
|                 int length = UNSAFE.getInt(rightAddress + 12); | ||||
|                 int hash = UNSAFE.getInt(rightAddress + 4); | ||||
|  | ||||
|                 for (int offset = offset(hash);; offset = next(offset)) { | ||||
|                     long address = pointer + offset; | ||||
|                     long ref = UNSAFE.getLong(address); | ||||
|                     int len = UNSAFE.getInt(address); | ||||
|  | ||||
|                     if (ref == 0) { | ||||
|                         UNSAFE.copyMemory(rightAddress, address, 32); | ||||
|                     if (len == 0) { | ||||
|                         UNSAFE.copyMemory(rightAddress, address, 24 + length); | ||||
|                         break; | ||||
|                     } | ||||
|  | ||||
|                     if (equal(ref, reference, length)) { | ||||
|                         long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16); | ||||
|                         int cnt = UNSAFE.getInt(address + 24) + UNSAFE.getInt(rightAddress + 24); | ||||
|                         short min = (short) Math.min(UNSAFE.getShort(address + 28), UNSAFE.getShort(rightAddress + 28)); | ||||
|                         short max = (short) Math.max(UNSAFE.getShort(address + 30), UNSAFE.getShort(rightAddress + 30)); | ||||
|                     if (len == length && equal(address + 24, rightAddress + 24, length)) { | ||||
|                         long sum = UNSAFE.getLong(address + 8) + UNSAFE.getLong(rightAddress + 8); | ||||
|                         int cnt = UNSAFE.getInt(address + 16) + UNSAFE.getInt(rightAddress + 16); | ||||
|                         short min = (short) Math.min(UNSAFE.getShort(address + 20), UNSAFE.getShort(rightAddress + 20)); | ||||
|                         short max = (short) Math.max(UNSAFE.getShort(address + 22), UNSAFE.getShort(rightAddress + 22)); | ||||
|  | ||||
|                         UNSAFE.putLong(address + 16, sum); | ||||
|                         UNSAFE.putInt(address + 24, cnt); | ||||
|                         UNSAFE.putShort(address + 28, min); | ||||
|                         UNSAFE.putShort(address + 30, max); | ||||
|                         UNSAFE.putLong(address + 8, sum); | ||||
|                         UNSAFE.putInt(address + 16, cnt); | ||||
|                         UNSAFE.putShort(address + 20, min); | ||||
|                         UNSAFE.putShort(address + 22, max); | ||||
|                         break; | ||||
|                     } | ||||
|                 } | ||||
| @@ -215,20 +229,19 @@ public class CalculateAverage_artsiomkorzun { | ||||
|         public Map<String, Aggregate> aggregate() { | ||||
|             TreeMap<String, Aggregate> set = new TreeMap<>(); | ||||
|  | ||||
|             for (int offset = 0; offset < SIZE; offset += 32) { | ||||
|             for (int offset = 0; offset < SIZE; offset += 128) { | ||||
|                 long address = pointer + offset; | ||||
|                 long ref = UNSAFE.getLong(address); | ||||
|                 int length = UNSAFE.getInt(address); | ||||
|  | ||||
|                 if (ref != 0) { | ||||
|                     int length = UNSAFE.getInt(address + 12) - 1; | ||||
|                 if (length != 0) { | ||||
|                     byte[] array = new byte[length]; | ||||
|                     UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); | ||||
|                     UNSAFE.copyMemory(null, address + 24, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); | ||||
|                     String key = new String(array); | ||||
|  | ||||
|                     long sum = UNSAFE.getLong(address + 16); | ||||
|                     int cnt = UNSAFE.getInt(address + 24); | ||||
|                     short min = UNSAFE.getShort(address + 28); | ||||
|                     short max = UNSAFE.getShort(address + 30); | ||||
|                     long sum = UNSAFE.getLong(address + 8); | ||||
|                     int cnt = UNSAFE.getInt(address + 16); | ||||
|                     short min = UNSAFE.getShort(address + 20); | ||||
|                     short max = UNSAFE.getShort(address + 22); | ||||
|  | ||||
|                     Aggregate aggregate = new Aggregate(min, max, sum, cnt); | ||||
|                     set.put(key, aggregate); | ||||
| @@ -238,26 +251,24 @@ public class CalculateAverage_artsiomkorzun { | ||||
|             return set; | ||||
|         } | ||||
|  | ||||
|         private static void alloc(long reference, int length, int hash, int value, long address) { | ||||
|             UNSAFE.putLong(address, reference); | ||||
|             UNSAFE.putInt(address + 8, hash); | ||||
|             UNSAFE.putInt(address + 12, length); | ||||
|             UNSAFE.putLong(address + 16, value); | ||||
|             UNSAFE.putInt(address + 24, 1); | ||||
|             UNSAFE.putShort(address + 28, (short) value); | ||||
|             UNSAFE.putShort(address + 30, (short) value); | ||||
|         private static void alloc(long reference, int length, int hash, long address) { | ||||
|             UNSAFE.putInt(address, length); | ||||
|             UNSAFE.putInt(address + 4, hash); | ||||
|             UNSAFE.putShort(address + 20, Short.MAX_VALUE); | ||||
|             UNSAFE.putShort(address + 22, Short.MIN_VALUE); | ||||
|             UNSAFE.copyMemory(reference, address + 24, length); | ||||
|         } | ||||
|  | ||||
|         private static int offset(int hash) { | ||||
|             return ((hash) & (ENTRIES - 1)) << 5; | ||||
|             return ((hash) & (ENTRIES - 1)) << 7; | ||||
|         } | ||||
|  | ||||
|         private static int next(int prev) { | ||||
|             return (prev + 32) & (SIZE - 1); | ||||
|             return (prev + 128) & (SIZE - 1); | ||||
|         } | ||||
|  | ||||
|         private static boolean equal(long leftAddress, long rightAddress, int length) { | ||||
|             while (length > 8) { | ||||
|         private static boolean equal(long leftAddress, long leftWord, long rightAddress, int length) { | ||||
|             while (length >= 8) { | ||||
|                 long left = UNSAFE.getLong(leftAddress); | ||||
|                 long right = UNSAFE.getLong(rightAddress); | ||||
|  | ||||
| @@ -270,10 +281,24 @@ public class CalculateAverage_artsiomkorzun { | ||||
|                 length -= 8; | ||||
|             } | ||||
|  | ||||
|             int shift = (8 - length) << 3; | ||||
|             long left = getLongLittleEndian(leftAddress) << shift; | ||||
|             long right = getLongLittleEndian(rightAddress) << shift; | ||||
|             return (left == right); | ||||
|             return leftWord == word(rightAddress); | ||||
|         } | ||||
|  | ||||
|         private static boolean equal(long leftAddress, long rightAddress, int length) { | ||||
|             do { | ||||
|                 long left = UNSAFE.getLong(leftAddress); | ||||
|                 long right = UNSAFE.getLong(rightAddress); | ||||
|  | ||||
|                 if (left != right) { | ||||
|                     return false; | ||||
|                 } | ||||
|  | ||||
|                 leftAddress += 8; | ||||
|                 rightAddress += 8; | ||||
|                 length -= 8; | ||||
|             } while (length > 0); | ||||
|  | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -320,45 +345,89 @@ public class CalculateAverage_artsiomkorzun { | ||||
|             // as a result a read will be split across pages, where one of them is not mapped | ||||
|             // but for some reason it works on my machine, leaving to investigate | ||||
|  | ||||
|             for (long start = position, hash = 0; position <= limit;) { | ||||
|                 int length; // idea: royvanrijn, explanation: https://richardstartin.github.io/posts/finding-bytes | ||||
|                 { | ||||
|                     long word = getLongLittleEndian(position); | ||||
|                     long match = word ^ COMMA_PATTERN; | ||||
|                     long mask = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L; | ||||
|             while (position <= limit) { // branchy version, credit: thomaswue | ||||
|                 int length; | ||||
|                 int hash; | ||||
|  | ||||
|                     if (mask == 0) { | ||||
|                         hash ^= word; | ||||
|                         position += 8; | ||||
|                         continue; | ||||
|                 long ptr = 0; | ||||
|                 long word = word(position); | ||||
|                 long separator = separator(word); | ||||
|  | ||||
|                 if (separator != 0) { | ||||
|                     length = length(separator); | ||||
|                     word = mask(word, separator); | ||||
|                     hash = mix(word); | ||||
|                     ptr = aggregates.find(word, hash); | ||||
|                 } | ||||
|                 else { | ||||
|                     long word0 = word; | ||||
|                     word = word(position + 8); | ||||
|                     separator = separator(word); | ||||
|  | ||||
|                     if (separator != 0) { | ||||
|                         length = length(separator) + 8; | ||||
|                         word = mask(word, separator); | ||||
|                         hash = mix(word ^ word0); | ||||
|                         ptr = aggregates.find(word0, word, hash); | ||||
|                     } | ||||
|                     else { | ||||
|                         length = 16; | ||||
|                         long h = word ^ word0; | ||||
|  | ||||
|                     int bit = Long.numberOfTrailingZeros(mask); | ||||
|                     position += (bit >>> 3) + 1; // +sep | ||||
|                     hash ^= (word << (69 - bit)); | ||||
|                     length = (int) (position - start); | ||||
|                         while (true) { | ||||
|                             word = word(position + length); | ||||
|                             separator = separator(word); | ||||
|  | ||||
|                             if (separator == 0) { | ||||
|                                 length += 8; | ||||
|                                 h ^= word; | ||||
|                                 continue; | ||||
|                             } | ||||
|  | ||||
|                             length += length(separator); | ||||
|                             word = mask(word, separator); | ||||
|                             hash = mix(h ^ word); | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 int value; // idea: merykitty | ||||
|                 { | ||||
|                     long word = getLongLittleEndian(position); | ||||
|                     long inverted = ~word; | ||||
|                     int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); | ||||
|                     long signed = (inverted << 59) >> 63; | ||||
|                     long mask = ~(signed & 0xFF); | ||||
|                     long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; | ||||
|                     long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; | ||||
|                     value = (int) ((abs ^ signed) - signed); | ||||
|                     position += (dot >> 3) + 3; | ||||
|                 if (ptr == 0) { | ||||
|                     ptr = aggregates.put(position, word, length, hash); | ||||
|                 } | ||||
|  | ||||
|                 aggregates.add(start, length, mix(hash), value); | ||||
|  | ||||
|                 start = position; | ||||
|                 hash = 0; | ||||
|                 position = update(ptr, position + length + 1); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         private static long update(long ptr, long position) { | ||||
|             // idea: merykitty | ||||
|             long word = word(position); | ||||
|             long inverted = ~word; | ||||
|             int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); | ||||
|             long signed = (inverted << 59) >> 63; | ||||
|             long mask = ~(signed & 0xFF); | ||||
|             long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; | ||||
|             long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; | ||||
|             int value = (int) ((abs ^ signed) - signed); | ||||
|  | ||||
|             Aggregates.update(ptr, value); | ||||
|             return position + (dot >> 3) + 3; | ||||
|         } | ||||
|  | ||||
|         private static long separator(long word) { | ||||
|             long match = word ^ COMMA_PATTERN; | ||||
|             return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L); | ||||
|         } | ||||
|  | ||||
|         private static long mask(long word, long separator) { | ||||
|             return word & ((separator >>> 7) - 1) & 0x00FFFFFFFFFFFFFFL; | ||||
|         } | ||||
|  | ||||
|         private static int length(long separator) { | ||||
|             return Long.numberOfTrailingZeros(separator) >>> 3; | ||||
|         } | ||||
|  | ||||
|         private static long next(long position) { | ||||
|             while (UNSAFE.getByte(position++) != '\n') { | ||||
|                 // continue | ||||
|   | ||||
		Reference in New Issue
	
	Block a user