Native build, less memory access, improved hash mixing (#449)
This commit is contained in:
		| @@ -24,11 +24,12 @@ import java.nio.channels.FileChannel.MapMode; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.nio.file.Path; | ||||
| import java.nio.file.StandardOpenOption; | ||||
| import java.util.Arrays; | ||||
| import java.util.TreeMap; | ||||
| import sun.misc.Unsafe; | ||||
|  | ||||
| public class CalculateAverage_abeobk { | ||||
|     private static final boolean SHOW_COLLISIONS = false; | ||||
|     private static final boolean SHOW_ANALYSIS = false; | ||||
|  | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|     private static final int BUCKET_SIZE = 1 << 16; | ||||
| @@ -99,13 +100,13 @@ public class CalculateAverage_abeobk { | ||||
|         boolean contentEquals(long other_addr, long other_tail) { | ||||
|             if (tail != other_tail) // compare tail & length at the same time | ||||
|                 return false; | ||||
|             long my_addr = addr; | ||||
|             int nl = (int) (tail >> 59); | ||||
|             for (int i = 0; i < nl; i++, my_addr += 8, other_addr += 8) { | ||||
|                 if (UNSAFE.getLong(my_addr) != UNSAFE.getLong(other_addr)) | ||||
|                     return false; | ||||
|             // this is faster than comparison if key is short | ||||
|             long xsum = 0; | ||||
|             int n = ((int) (tail >>> 56)) & 0xF8; | ||||
|             for (int i = 0; i < n; i += 8) { | ||||
|                 xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i)); | ||||
|             } | ||||
|             return true; | ||||
|             return xsum == 0; | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -123,6 +124,7 @@ public class CalculateAverage_abeobk { | ||||
|         return ptrs; | ||||
|     } | ||||
|  | ||||
|     // idea from royvanrijn | ||||
|     static final long getSemiPosCode(final long word) { | ||||
|         long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;; | ||||
|         return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L); | ||||
| @@ -133,17 +135,164 @@ public class CalculateAverage_abeobk { | ||||
|     // zero collision on test data | ||||
|     static final int xxh32(long hash) { | ||||
|         final int p1 = 0x85EBCA77; // prime | ||||
|         final int p2 = 0xC2B2AE3D; // prime | ||||
|         final int p2 = 0x165667B1; // prime | ||||
|         int low = (int) hash; | ||||
|         int high = (int) (hash >>> 32); | ||||
|         low ^= low >> 15; | ||||
|         low *= p1; | ||||
|         high ^= high >> 13; | ||||
|         high *= p2; | ||||
|         var h = low ^ high; | ||||
|         int high = (int) (hash >>> 31); | ||||
|         int h = low + high; | ||||
|         h ^= h >> 15; | ||||
|         h *= p1; | ||||
|         h ^= h >> 13; | ||||
|         h *= p2; | ||||
|         h ^= h >> 11; | ||||
|         return h; | ||||
|     } | ||||
|  | ||||
|     // great idea from merykitty (Quan Anh Mai): branch-free decode of num_word into a signed int (presumably the reading in tenths - confirm); dot_pos is the bit index the caller derives from ~num_word & 0x10101000 | ||||
|     static final int parseNum(long num_word, int dot_pos) { | ||||
|         int shift = 28 - dot_pos; | ||||
|         long signed = (~num_word << 59) >> 63; | ||||
|         long dsmask = ~(signed & 0xFF); | ||||
|         long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; | ||||
|         long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; | ||||
|         return (int) ((abs_val ^ signed) - signed); | ||||
|     } | ||||
|  | ||||
|     // Parses the rows in [start, end) into a linearly probed Node table and returns it; optimized for the contest data | ||||
|     // saves as much slow memory access as possible by matching short keys without a full content compare | ||||
|     // about 50% of keys are < 8 chars, 25% of keys are between 8-10 chars | ||||
|     // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... | ||||
|     static final Node[] parse(int thread_id, long start, long end, int[] cls) { | ||||
|         long addr = start; | ||||
|         var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions | ||||
|         // parse loop: one row (key;number) is consumed per iteration | ||||
|         while (addr < end) { | ||||
|             long row_addr = addr; | ||||
|             long tail = 0; | ||||
|             long hash = 0; | ||||
|             int val = 0; | ||||
|             int bucket = 0; | ||||
|  | ||||
|             long word = UNSAFE.getLong(addr); | ||||
|             long semipos_code = getSemiPosCode(word); | ||||
|  | ||||
|             // about 50% chance key < 8 chars: the masked first word plus the key length (tail) is all that is compared | ||||
|             if (semipos_code != 0) { | ||||
|                 int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; | ||||
|                 addr += semi_pos; | ||||
|                 tail = (word & HASH_MASKS[semi_pos]); | ||||
|                 bucket = xxh32(tail) & BUCKET_MASK; | ||||
|                 long keylen = (addr - row_addr); | ||||
|                 tail |= (keylen << 56); | ||||
|                 long num_word = UNSAFE.getLong(++addr); | ||||
|                 int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); | ||||
|                 val = parseNum(num_word, dot_pos); | ||||
|                 addr += (dot_pos >>> 3) + 3; | ||||
|  | ||||
|                 while (true) { | ||||
|                     var node = map[bucket]; | ||||
|                     if (node == null) { | ||||
|                         map[bucket] = new Node(row_addr, tail, val); | ||||
|                         break; | ||||
|                     } | ||||
|                     if (node.tail == tail) { | ||||
|                         node.add(val); | ||||
|                         break; | ||||
|                     } | ||||
|                     bucket++; | ||||
|                     if (SHOW_ANALYSIS) | ||||
|                         cls[thread_id]++; | ||||
|                 } | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             hash ^= word; | ||||
|             addr += 8; | ||||
|             word = UNSAFE.getLong(addr); | ||||
|             semipos_code = getSemiPosCode(word); | ||||
|             // first byte is the semicolon (~13%), so the key is exactly 8 chars; NOTE(review): the match below compares only the first 8 key bytes and never node.tail - confirm a longer key sharing this 8-byte prefix cannot sit in the same probe chain | ||||
|             if (semipos_code == 0x80) { | ||||
|                 bucket = xxh32(hash) & BUCKET_MASK; | ||||
|                 tail = 8L << 56; | ||||
|                 long num_word = word >>> 8; | ||||
|                 int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); | ||||
|                 val = parseNum(num_word, dot_pos); | ||||
|                 addr += (dot_pos >>> 3) + 4; | ||||
|  | ||||
|                 while (true) { | ||||
|                     var node = map[bucket]; | ||||
|                     if (node == null) { | ||||
|                         map[bucket] = new Node(row_addr, tail, val); | ||||
|                         break; | ||||
|                     } | ||||
|                     if (UNSAFE.getLong(node.addr) == UNSAFE.getLong(row_addr)) { | ||||
|                         node.add(val); | ||||
|                         break; | ||||
|                     } | ||||
|                     bucket++; | ||||
|                     if (SHOW_ANALYSIS) | ||||
|                         cls[thread_id]++; | ||||
|                 } | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             while (semipos_code == 0) { | ||||
|                 hash ^= word; | ||||
|                 addr += 8; | ||||
|                 word = UNSAFE.getLong(addr); | ||||
|                 semipos_code = getSemiPosCode(word); | ||||
|             } | ||||
|  | ||||
|             int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; | ||||
|             addr += semi_pos; | ||||
|             tail = (word & HASH_MASKS[semi_pos]); | ||||
|             hash ^= tail; | ||||
|             bucket = xxh32(hash) & BUCKET_MASK; | ||||
|             long keylen = (addr - row_addr); | ||||
|             tail |= (keylen << 56); | ||||
|  | ||||
|             ++addr; | ||||
|             long num_word = UNSAFE.getLong(addr); | ||||
|             int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); | ||||
|             val = parseNum(num_word, dot_pos); | ||||
|             addr += (dot_pos >>> 3) + 3; | ||||
|  | ||||
|             if (keylen < 16) { | ||||
|                 while (true) { | ||||
|                     var node = map[bucket]; | ||||
|                     if (node == null) { | ||||
|                         map[bucket] = new Node(row_addr, tail, val); | ||||
|                         break; | ||||
|                     } | ||||
|                     if (node.tail == tail && (UNSAFE.getLong(node.addr) == UNSAFE.getLong(row_addr))) { | ||||
|                         node.add(val); | ||||
|                         break; | ||||
|                     } | ||||
|                     bucket++; | ||||
|                     if (SHOW_ANALYSIS) | ||||
|                         cls[thread_id]++; | ||||
|                 } | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             // longer key (keylen >= 16): delegate the match to Node.contentEquals | ||||
|             while (true) { | ||||
|                 var node = map[bucket]; | ||||
|                 if (node == null) { | ||||
|                     map[bucket] = new Node(row_addr, tail, val); | ||||
|                     break; | ||||
|                 } | ||||
|                 if (node.contentEquals(row_addr, tail)) { | ||||
|                     node.add(val); | ||||
|                     break; | ||||
|                 } | ||||
|                 bucket++; | ||||
|                 if (SHOW_ANALYSIS) | ||||
|                     cls[thread_id]++; | ||||
|             } | ||||
|         } | ||||
|         return map; | ||||
|     } | ||||
|  | ||||
|     public static void main(String[] args) throws InterruptedException, IOException { | ||||
|         try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { | ||||
|             long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); | ||||
| @@ -158,71 +307,14 @@ public class CalculateAverage_abeobk { | ||||
|             var threads = new Thread[cpu_cnt]; | ||||
|             var maps = new Node[cpu_cnt][]; | ||||
|             var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt); | ||||
|             int[] cls = new int[cpu_cnt]; | ||||
|  | ||||
|             int[] cls = new int[cpu_cnt]; // collision | ||||
|             int[] lenhist = new int[64]; // length histogram | ||||
|  | ||||
|             for (int i = 0; i < cpu_cnt; i++) { | ||||
|                 int thread_id = i; | ||||
|                 long start = ptrs[i]; | ||||
|                 long end = ptrs[i + 1]; | ||||
|                 maps[i] = new Node[BUCKET_SIZE + 10000]; // extra space for collisions | ||||
|  | ||||
|                 (threads[i] = new Thread(() -> { | ||||
|                     long addr = start; | ||||
|                     var map = maps[thread_id]; | ||||
|                     // parse loop | ||||
|                     while (addr < end) { | ||||
|                         long hash = 0; | ||||
|                         long word = 0; | ||||
|                         long row_addr = addr; | ||||
|                         int semi_pos = 8; | ||||
|                         word = UNSAFE.getLong(addr); | ||||
|                         long semipos_code = getSemiPosCode(word); | ||||
|  | ||||
|                         while (semipos_code == 0) { | ||||
|                             hash ^= word; | ||||
|                             addr += 8; | ||||
|                             word = UNSAFE.getLong(addr); | ||||
|                             semipos_code = getSemiPosCode(word); | ||||
|                         } | ||||
|  | ||||
|                         semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; | ||||
|                         long tail = word & HASH_MASKS[semi_pos]; | ||||
|                         hash ^= tail; | ||||
|                         addr += semi_pos; | ||||
|  | ||||
|                         int hash32 = xxh32(hash); | ||||
|                         long keylen = (addr - row_addr); | ||||
|                         tail = tail | (keylen << 56); | ||||
|  | ||||
|                         addr++; | ||||
|  | ||||
|                         // great idea from merykitty (Quan Anh Mai) | ||||
|                         long num_word = UNSAFE.getLong(addr); | ||||
|                         int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); | ||||
|                         addr += (dot_pos >>> 3) + 3; | ||||
|                         int shift = 28 - dot_pos; | ||||
|                         long signed = (~num_word << 59) >> 63; | ||||
|                         long dsmask = ~(signed & 0xFF); | ||||
|                         long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; | ||||
|                         long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; | ||||
|                         int val = (int) ((abs_val ^ signed) - signed); | ||||
|  | ||||
|                         int bucket = (hash32 & BUCKET_MASK); | ||||
|                         while (true) { | ||||
|                             var node = map[bucket]; | ||||
|                             if (node == null) { | ||||
|                                 map[bucket] = new Node(row_addr, tail, val); | ||||
|                                 break; | ||||
|                             } | ||||
|                             if (node.contentEquals(row_addr, tail)) { | ||||
|                                 node.add(val); | ||||
|                                 break; | ||||
|                             } | ||||
|                             bucket++; | ||||
|                             if (SHOW_COLLISIONS) | ||||
|                                 cls[thread_id]++; | ||||
|                         } | ||||
|                     } | ||||
|                 (threads[thread_id] = new Thread(() -> { | ||||
|                     maps[thread_id] = parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1], cls); | ||||
|                 })).start(); | ||||
|             } | ||||
|  | ||||
| @@ -230,7 +322,7 @@ public class CalculateAverage_abeobk { | ||||
|             for (var thread : threads) | ||||
|                 thread.join(); | ||||
|  | ||||
|             if (SHOW_COLLISIONS) { | ||||
|             if (SHOW_ANALYSIS) { | ||||
|                 for (int i = 0; i < cpu_cnt; i++) { | ||||
|                     System.out.println("thread-" + i + " collision = " + cls[i]); | ||||
|                 } | ||||
| @@ -242,13 +334,22 @@ public class CalculateAverage_abeobk { | ||||
|                 for (var node : map) { | ||||
|                     if (node == null) | ||||
|                         continue; | ||||
|                     if (SHOW_ANALYSIS) { | ||||
|                         int kl = (int) (node.tail >>> 56) & (lenhist.length - 1); | ||||
|                         lenhist[kl] += node.count; | ||||
|                     } | ||||
|                     var stat = ms.putIfAbsent(node.key(), node); | ||||
|                     if (stat != null) | ||||
|                         stat.merge(node); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             if (!SHOW_COLLISIONS) | ||||
|             if (SHOW_ANALYSIS) { | ||||
|                 System.out.println("total=" + Arrays.stream(lenhist).sum()); | ||||
|                 System.out.println("length_histogram = " | ||||
|                         + Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray())); | ||||
|             } | ||||
|             else | ||||
|                 System.out.println(ms); | ||||
|         } | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user