Optimized with fewer constructor args + low-collision mixer (#420)

* use all CPUs

* use graal

* optimized with fewer constructor args

* optimized with low-collision mixer
This commit is contained in:
Van Phu DO 2024-01-16 02:53:31 +09:00 committed by GitHub
parent ecab306338
commit 677d94e5cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -28,6 +28,8 @@ import java.util.TreeMap;
import sun.misc.Unsafe; import sun.misc.Unsafe;
public class CalculateAverage_abeobk { public class CalculateAverage_abeobk {
private static final boolean SHOW_COLLISIONS = false;
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
private static final int BUCKET_SIZE = 1 << 16; private static final int BUCKET_SIZE = 1 << 16;
private static final int BUCKET_MASK = BUCKET_SIZE - 1; private static final int BUCKET_MASK = BUCKET_SIZE - 1;
@ -55,69 +57,55 @@ public class CalculateAverage_abeobk {
} }
} }
// stat static class Node {
private static class Stat { long addr;
private int min; long tail;
private int max; int min, max;
private long sum; int count;
private int count; long sum;
Stat(int v) { String key() {
sum = min = max = v; byte[] sbuf = new byte[MAX_STR_LEN];
count = 1; int keylen = (int) (tail >>> 56);
} UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
return new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
void add(int val) {
min = Math.min(val, min);
max = Math.max(val, max);
sum += val;
count++;
}
void merge(Stat other) {
min = Math.min(other.min, min);
max = Math.max(other.max, max);
sum += other.sum;
count += other.count;
} }
public String toString() { public String toString() {
return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1); return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1);
} }
}
static class Node { Node(long a, long t, int val) {
long addr;
int keylen;
int hash;
long[] buf = new long[13];
Stat stat;
String key() {
byte[] buf = new byte[MAX_STR_LEN];
UNSAFE.copyMemory(null, addr, buf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
return new String(buf, 0, keylen, StandardCharsets.UTF_8);
}
Node(long a, int kl, int h, int v, long[] b) {
stat = new Stat(v);
addr = a; addr = a;
keylen = kl; tail = t;
hash = h; sum = min = max = val;
System.arraycopy(b, 0, buf, 0, Math.ceilDiv(kl, 8)); count = 1;
} }
boolean contentEquals(final long[] other_buf) { void add(int val) {
int k = keylen / 8; min = Math.min(min, val);
int r = keylen % 8; max = Math.max(max, val);
// Since the city name is most likely shorter than 16 characters sum += val;
// this should be faster than typical conditional checks count++;
long sum = 0; }
for (int i = 0; i < k; i++) {
sum += buf[i] ^ other_buf[i]; void merge(Node other) {
min = Math.min(min, other.min);
max = Math.max(max, other.max);
sum += other.sum;
count += other.count;
}
boolean contentEquals(long other_addr, long other_tail) {
if (tail != other_tail) // compare tail & length at the same time
return false;
long my_addr = addr;
int nl = (int) (tail >> 59);
for (int i = 0; i < nl; i++, my_addr += 8, other_addr += 8) {
if (UNSAFE.getLong(my_addr) != UNSAFE.getLong(other_addr))
return false;
} }
sum += (buf[k] ^ other_buf[k]) & HASH_MASKS[r]; return true;
return sum == 0;
} }
} }
@ -135,55 +123,83 @@ public class CalculateAverage_abeobk {
return ptrs; return ptrs;
} }
static final long getSemiPosCode(final long word) {
long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;;
return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L);
}
// very low collision mixer
// idea from https://github.com/Cyan4973/xxHash/tree/dev
// zero collision on test data
static final int xxh32(long hash) {
final int p1 = 0x85EBCA77; // prime
final int p2 = 0xC2B2AE3D; // prime
int low = (int) hash;
int high = (int) (hash >>> 32);
low ^= low >> 15;
low *= p1;
high ^= high >> 13;
high *= p2;
var h = low ^ high;
return h;
}
public static void main(String[] args) throws InterruptedException, IOException { public static void main(String[] args) throws InterruptedException, IOException {
int cpu_cnt = Runtime.getRuntime().availableProcessors();
try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
long file_size = file.size(); long file_size = file.size();
long end_addr = start_addr + file_size; long end_addr = start_addr + file_size;
// only use all cpus on large file
int cpu_cnt = file_size < 1e6 ? 1 : Runtime.getRuntime().availableProcessors();
long chunk_size = Math.ceilDiv(file_size, cpu_cnt); long chunk_size = Math.ceilDiv(file_size, cpu_cnt);
// processing // processing
var threads = new Thread[cpu_cnt]; var threads = new Thread[cpu_cnt];
var maps = new Node[cpu_cnt][]; var maps = new Node[cpu_cnt][];
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt); var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
int[] cls = new int[cpu_cnt];
for (int i = 0; i < cpu_cnt; i++) { for (int i = 0; i < cpu_cnt; i++) {
int thread_id = i; int thread_id = i;
long start = ptrs[i]; long start = ptrs[i];
long end = ptrs[i + 1]; long end = ptrs[i + 1];
maps[i] = new Node[BUCKET_SIZE + 16]; // extra space for collisions maps[i] = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
(threads[i] = new Thread(() -> { (threads[i] = new Thread(() -> {
long addr = start; long addr = start;
var map = maps[thread_id]; var map = maps[thread_id];
long[] buf = new long[13];
// parse loop // parse loop
while (addr < end) { while (addr < end) {
int idx = 0;
long hash = 0; long hash = 0;
long word = 0; long word = 0;
long row_addr = addr; long row_addr = addr;
int semi_pos = 8; int semi_pos = 8;
while (semi_pos == 8) { word = UNSAFE.getLong(addr);
long semipos_code = getSemiPosCode(word);
while (semipos_code == 0) {
hash ^= word;
addr += 8;
word = UNSAFE.getLong(addr); word = UNSAFE.getLong(addr);
buf[idx++] = word; semipos_code = getSemiPosCode(word);
// idea from thomaswue & royvanrijn
long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;;
long semipos_code = (xor_semi - 0x0101010101010101L) & ~xor_semi & 0x8080808080808080L;
semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
addr += semi_pos;
hash ^= word & HASH_MASKS[semi_pos];
} }
int hash32 = (int) (hash ^ (hash >>> 31)); semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
int keylen = (int) (addr - row_addr); long tail = word & HASH_MASKS[semi_pos];
hash ^= tail;
addr += semi_pos;
int hash32 = xxh32(hash);
long keylen = (addr - row_addr);
tail = tail | (keylen << 56);
addr++;
// great idea from merykitty (Quan Anh Mai) // great idea from merykitty (Quan Anh Mai)
long num_word = UNSAFE.getLong(++addr); long num_word = UNSAFE.getLong(addr);
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 3; addr += (dot_pos >>> 3) + 3;
int shift = 28 - dot_pos; int shift = 28 - dot_pos;
long signed = (~num_word << 59) >> 63; long signed = (~num_word << 59) >> 63;
long dsmask = ~(signed & 0xFF); long dsmask = ~(signed & 0xFF);
@ -195,14 +211,16 @@ public class CalculateAverage_abeobk {
while (true) { while (true) {
var node = map[bucket]; var node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, keylen, hash32, val, buf); map[bucket] = new Node(row_addr, tail, val);
break; break;
} }
if (node.keylen == keylen && node.hash == hash32 && node.contentEquals(buf)) { if (node.contentEquals(row_addr, tail)) {
node.stat.add(val); node.add(val);
break; break;
} }
bucket++; bucket++;
if (SHOW_COLLISIONS)
cls[thread_id]++;
} }
} }
})).start(); })).start();
@ -212,19 +230,26 @@ public class CalculateAverage_abeobk {
for (var thread : threads) for (var thread : threads)
thread.join(); thread.join();
if (SHOW_COLLISIONS) {
for (int i = 0; i < cpu_cnt; i++) {
System.out.println("thread-" + i + " collision = " + cls[i]);
}
}
// collect results // collect results
TreeMap<String, Stat> ms = new TreeMap<>(); TreeMap<String, Node> ms = new TreeMap<>();
for (var map : maps) { for (var map : maps) {
for (var node : map) { for (var node : map) {
if (node == null) if (node == null)
continue; continue;
var stat = ms.putIfAbsent(node.key(), node.stat); var stat = ms.putIfAbsent(node.key(), node);
if (stat != null) if (stat != null)
stat.merge(node.stat); stat.merge(node);
} }
} }
System.out.println(ms); if (!SHOW_COLLISIONS)
System.out.println(ms);
} }
} }
} }