Optimized with less constructor args + low collision mixer (#420)
* use all CPUs * use graal * optimized with less constructor arg * optimized with low collision mixer
This commit is contained in:
		| @@ -28,6 +28,8 @@ import java.util.TreeMap; | ||||
| import sun.misc.Unsafe; | ||||
|  | ||||
| public class CalculateAverage_abeobk { | ||||
|     private static final boolean SHOW_COLLISIONS = false; | ||||
|  | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|     private static final int BUCKET_SIZE = 1 << 16; | ||||
|     private static final int BUCKET_MASK = BUCKET_SIZE - 1; | ||||
| @@ -55,69 +57,55 @@ public class CalculateAverage_abeobk { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // stat | ||||
|     private static class Stat { | ||||
|         private int min; | ||||
|         private int max; | ||||
|         private long sum; | ||||
|         private int count; | ||||
|     static class Node { | ||||
|         long addr; | ||||
|         long tail; | ||||
|         int min, max; | ||||
|         int count; | ||||
|         long sum; | ||||
|  | ||||
|         Stat(int v) { | ||||
|             sum = min = max = v; | ||||
|             count = 1; | ||||
|         } | ||||
|  | ||||
|         void add(int val) { | ||||
|             min = Math.min(val, min); | ||||
|             max = Math.max(val, max); | ||||
|             sum += val; | ||||
|             count++; | ||||
|         } | ||||
|  | ||||
|         void merge(Stat other) { | ||||
|             min = Math.min(other.min, min); | ||||
|             max = Math.max(other.max, max); | ||||
|             sum += other.sum; | ||||
|             count += other.count; | ||||
|         String key() { | ||||
|             byte[] sbuf = new byte[MAX_STR_LEN]; | ||||
|             int keylen = (int) (tail >>> 56); | ||||
|             UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen); | ||||
|             return new String(sbuf, 0, keylen, StandardCharsets.UTF_8); | ||||
|         } | ||||
|  | ||||
|         public String toString() { | ||||
|             return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     static class Node { | ||||
|         long addr; | ||||
|         int keylen; | ||||
|         int hash; | ||||
|         long[] buf = new long[13]; | ||||
|         Stat stat; | ||||
|  | ||||
|         String key() { | ||||
|             byte[] buf = new byte[MAX_STR_LEN]; | ||||
|             UNSAFE.copyMemory(null, addr, buf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen); | ||||
|             return new String(buf, 0, keylen, StandardCharsets.UTF_8); | ||||
|         } | ||||
|  | ||||
|         Node(long a, int kl, int h, int v, long[] b) { | ||||
|             stat = new Stat(v); | ||||
|         Node(long a, long t, int val) { | ||||
|             addr = a; | ||||
|             keylen = kl; | ||||
|             hash = h; | ||||
|             System.arraycopy(b, 0, buf, 0, Math.ceilDiv(kl, 8)); | ||||
|             tail = t; | ||||
|             sum = min = max = val; | ||||
|             count = 1; | ||||
|         } | ||||
|  | ||||
|         boolean contentEquals(final long[] other_buf) { | ||||
|             int k = keylen / 8; | ||||
|             int r = keylen % 8; | ||||
|             // Since the city name is most likely shorter than 16 characters | ||||
|             // this should be faster than typical conditional checks | ||||
|             long sum = 0; | ||||
|             for (int i = 0; i < k; i++) { | ||||
|                 sum += buf[i] ^ other_buf[i]; | ||||
|         void add(int val) { | ||||
|             min = Math.min(min, val); | ||||
|             max = Math.max(max, val); | ||||
|             sum += val; | ||||
|             count++; | ||||
|         } | ||||
|  | ||||
|         void merge(Node other) { | ||||
|             min = Math.min(min, other.min); | ||||
|             max = Math.max(max, other.max); | ||||
|             sum += other.sum; | ||||
|             count += other.count; | ||||
|         } | ||||
|  | ||||
|         boolean contentEquals(long other_addr, long other_tail) { | ||||
|             if (tail != other_tail) // compare tail & length at the same time | ||||
|                 return false; | ||||
|             long my_addr = addr; | ||||
|             int nl = (int) (tail >> 59); | ||||
|             for (int i = 0; i < nl; i++, my_addr += 8, other_addr += 8) { | ||||
|                 if (UNSAFE.getLong(my_addr) != UNSAFE.getLong(other_addr)) | ||||
|                     return false; | ||||
|             } | ||||
|             sum += (buf[k] ^ other_buf[k]) & HASH_MASKS[r]; | ||||
|             return sum == 0; | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -135,55 +123,83 @@ public class CalculateAverage_abeobk { | ||||
|         return ptrs; | ||||
|     } | ||||
|  | ||||
|     static final long getSemiPosCode(final long word) { | ||||
|         long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;; | ||||
|         return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L); | ||||
|     } | ||||
|  | ||||
|     // very low collision mixer | ||||
|     // idea from https://github.com/Cyan4973/xxHash/tree/dev | ||||
|     // zero collision on test data | ||||
|     static final int xxh32(long hash) { | ||||
|         final int p1 = 0x85EBCA77; // prime | ||||
|         final int p2 = 0xC2B2AE3D; // prime | ||||
|         int low = (int) hash; | ||||
|         int high = (int) (hash >>> 32); | ||||
|         low ^= low >> 15; | ||||
|         low *= p1; | ||||
|         high ^= high >> 13; | ||||
|         high *= p2; | ||||
|         var h = low ^ high; | ||||
|         return h; | ||||
|     } | ||||
|  | ||||
|     public static void main(String[] args) throws InterruptedException, IOException { | ||||
|         int cpu_cnt = Runtime.getRuntime().availableProcessors(); | ||||
|         try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { | ||||
|             long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); | ||||
|             long file_size = file.size(); | ||||
|             long end_addr = start_addr + file_size; | ||||
|  | ||||
|             // only use all cpus on large file | ||||
|             int cpu_cnt = file_size < 1e6 ? 1 : Runtime.getRuntime().availableProcessors(); | ||||
|             long chunk_size = Math.ceilDiv(file_size, cpu_cnt); | ||||
|  | ||||
|             // processing | ||||
|             var threads = new Thread[cpu_cnt]; | ||||
|             var maps = new Node[cpu_cnt][]; | ||||
|             var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt); | ||||
|             int[] cls = new int[cpu_cnt]; | ||||
|  | ||||
|             for (int i = 0; i < cpu_cnt; i++) { | ||||
|                 int thread_id = i; | ||||
|                 long start = ptrs[i]; | ||||
|                 long end = ptrs[i + 1]; | ||||
|                 maps[i] = new Node[BUCKET_SIZE + 16]; // extra space for collisions | ||||
|                 maps[i] = new Node[BUCKET_SIZE + 10000]; // extra space for collisions | ||||
|  | ||||
|                 (threads[i] = new Thread(() -> { | ||||
|                     long addr = start; | ||||
|                     var map = maps[thread_id]; | ||||
|                     long[] buf = new long[13]; | ||||
|                     // parse loop | ||||
|                     while (addr < end) { | ||||
|                         int idx = 0; | ||||
|                         long hash = 0; | ||||
|                         long word = 0; | ||||
|                         long row_addr = addr; | ||||
|                         int semi_pos = 8; | ||||
|                         while (semi_pos == 8) { | ||||
|                         word = UNSAFE.getLong(addr); | ||||
|                         long semipos_code = getSemiPosCode(word); | ||||
|  | ||||
|                         while (semipos_code == 0) { | ||||
|                             hash ^= word; | ||||
|                             addr += 8; | ||||
|                             word = UNSAFE.getLong(addr); | ||||
|                             buf[idx++] = word; | ||||
|                             // idea from thomaswue & royvanrijn | ||||
|                             long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;; | ||||
|                             long semipos_code = (xor_semi - 0x0101010101010101L) & ~xor_semi & 0x8080808080808080L; | ||||
|                             semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; | ||||
|                             addr += semi_pos; | ||||
|                             hash ^= word & HASH_MASKS[semi_pos]; | ||||
|                             semipos_code = getSemiPosCode(word); | ||||
|                         } | ||||
|  | ||||
|                         int hash32 = (int) (hash ^ (hash >>> 31)); | ||||
|                         int keylen = (int) (addr - row_addr); | ||||
|                         semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; | ||||
|                         long tail = word & HASH_MASKS[semi_pos]; | ||||
|                         hash ^= tail; | ||||
|                         addr += semi_pos; | ||||
|  | ||||
|                         int hash32 = xxh32(hash); | ||||
|                         long keylen = (addr - row_addr); | ||||
|                         tail = tail | (keylen << 56); | ||||
|  | ||||
|                         addr++; | ||||
|  | ||||
|                         // great idea from merykitty (Quan Anh Mai) | ||||
|                         long num_word = UNSAFE.getLong(++addr); | ||||
|                         long num_word = UNSAFE.getLong(addr); | ||||
|                         int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); | ||||
|                         addr += (dot_pos >>> 3) + 3; | ||||
|  | ||||
|                         int shift = 28 - dot_pos; | ||||
|                         long signed = (~num_word << 59) >> 63; | ||||
|                         long dsmask = ~(signed & 0xFF); | ||||
| @@ -195,14 +211,16 @@ public class CalculateAverage_abeobk { | ||||
|                         while (true) { | ||||
|                             var node = map[bucket]; | ||||
|                             if (node == null) { | ||||
|                                 map[bucket] = new Node(row_addr, keylen, hash32, val, buf); | ||||
|                                 map[bucket] = new Node(row_addr, tail, val); | ||||
|                                 break; | ||||
|                             } | ||||
|                             if (node.keylen == keylen && node.hash == hash32 && node.contentEquals(buf)) { | ||||
|                                 node.stat.add(val); | ||||
|                             if (node.contentEquals(row_addr, tail)) { | ||||
|                                 node.add(val); | ||||
|                                 break; | ||||
|                             } | ||||
|                             bucket++; | ||||
|                             if (SHOW_COLLISIONS) | ||||
|                                 cls[thread_id]++; | ||||
|                         } | ||||
|                     } | ||||
|                 })).start(); | ||||
| @@ -212,19 +230,26 @@ public class CalculateAverage_abeobk { | ||||
|             for (var thread : threads) | ||||
|                 thread.join(); | ||||
|  | ||||
|             if (SHOW_COLLISIONS) { | ||||
|                 for (int i = 0; i < cpu_cnt; i++) { | ||||
|                     System.out.println("thread-" + i + " collision = " + cls[i]); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // collect results | ||||
|             TreeMap<String, Stat> ms = new TreeMap<>(); | ||||
|             TreeMap<String, Node> ms = new TreeMap<>(); | ||||
|             for (var map : maps) { | ||||
|                 for (var node : map) { | ||||
|                     if (node == null) | ||||
|                         continue; | ||||
|                     var stat = ms.putIfAbsent(node.key(), node.stat); | ||||
|                     var stat = ms.putIfAbsent(node.key(), node); | ||||
|                     if (stat != null) | ||||
|                         stat.merge(node.stat); | ||||
|                         stat.merge(node); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             System.out.println(ms); | ||||
|             if (!SHOW_COLLISIONS) | ||||
|                 System.out.println(ms); | ||||
|         } | ||||
|     } | ||||
| } | ||||
		Reference in New Issue
	
	Block a user