Use native type, remove lots of type conversions (#618)

* less type conversion, less string cast

* adjust some comments

* fixed format issue
This commit is contained in:
Van Phu DO 2024-01-29 02:08:42 +09:00 committed by GitHub
parent d5854d65e6
commit a33ed2181b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -26,9 +26,9 @@ import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;
import sun.misc.Unsafe;
@ -39,7 +39,7 @@ public class CalculateAverage_abeobk {
private static final String FILE = "./measurements.txt";
private static final int BUCKET_SIZE = 1 << 16;
private static final int BUCKET_MASK = BUCKET_SIZE - 1;
private static final long BUCKET_MASK = BUCKET_SIZE - 1;
private static final int MAX_STR_LEN = 100;
private static final int MAX_STATIONS = 10000;
private static final long CHUNK_SZ = 1 << 22; // 4MB chunk
@ -56,9 +56,9 @@ public class CalculateAverage_abeobk {
0xffffffffffffffffL, };
private static AtomicInteger chunk_id = new AtomicInteger(0);
private static AtomicReference<Node[]> mapref = new AtomicReference<>(null);
private static int chunk_cnt;
private static long start_addr, end_addr;
private static Stat[][] all_res;
private static final void debug(String s, Object... args) {
System.out.println(String.format(s, args));
@ -75,57 +75,49 @@ public class CalculateAverage_abeobk {
}
}
/**
 * Pairs an aggregated {@link Node} with the decoded station name found at the
 * node's raw address.
 */
static class Stat {
    Node node;
    String key;

    /**
     * Renders the aggregate as {@code min/mean/max}. Each stored value is
     * divided by 10.0; the mean is rounded with {@link Math#round(double)}
     * before that division.
     */
    public final String toString() {
        double lowest = node.min / 10.0;
        double mean = Math.round((double) node.sum / node.count) / 10.0;
        double highest = node.max / 10.0;
        return lowest + "/" + mean + "/" + highest;
    }

    /**
     * Wraps {@code n} and decodes its key from memory.
     *
     * @param n node whose {@code addr} points at the first byte of the key
     */
    Stat(Node n) {
        node = n;
        // Walk the key one 8-byte word at a time until getSemiPosCode reports
        // a non-zero code for the current word.
        int wordOffset = 0;
        long code;
        for (;;) {
            code = getSemiPosCode(UNSAFE.getLong(n.addr + wordOffset));
            if (code != 0) {
                break;
            }
            wordOffset += 8;
        }
        // The trailing-zero count is in bits; >>> 3 turns it into a byte index
        // within the last word read.
        int length = wordOffset + (Long.numberOfTrailingZeros(code) >>> 3);
        byte[] bytes = new byte[MAX_STR_LEN];
        UNSAFE.copyMemory(null, n.addr, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
        key = new String(bytes, 0, length, StandardCharsets.UTF_8);
    }
}
// use native type, less conversion
static class Node {
long addr;
long hash;
long word0;
long tail;
long sum;
long min, max;
int keylen;
int count;
short min, max;
Node(long a, long t, short val) {
/**
 * Renders this node as {@code min/mean/max}. Each stored value is divided by
 * 10.0; the mean is rounded with {@link Math#round(double)} before that
 * division.
 */
public final String toString() {
    double lowest = min / 10.0;
    double mean = Math.round((double) sum / count) / 10.0;
    double highest = max / 10.0;
    return lowest + "/" + mean + "/" + highest;
}
/**
 * Decodes this node's key: copies {@code keylen} bytes starting at the raw
 * address {@code addr} into a heap array and interprets them as UTF-8.
 *
 * @return the key as a String
 */
final String key() {
    // Size the buffer from keylen rather than a fixed capacity:
    // Unsafe.copyMemory does no bounds checking, so a key longer than a
    // fixed-size array would silently write past its end.
    byte[] sbuf = new byte[keylen];
    UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
    return new String(sbuf, StandardCharsets.UTF_8);
}
/**
 * Creates a node from its first sample without setting {@code word0}, which
 * keeps its default value.
 *
 * @param a   raw address of the key's first byte
 * @param t   masked tail word of the key
 * @param kl  key length in bytes
 * @param h   hash computed for the key
 * @param val first sample
 */
Node(long a, long t, int kl, long h, long val) {
    // identity of the key
    addr = a;
    keylen = kl;
    hash = h;
    tail = t;
    // a single sample is at once the minimum, the maximum and the running sum
    min = val;
    max = val;
    sum = val;
    count = 1;
}
Node(long a, long w0, long t, short val) {
/**
 * Creates a node from its first sample, remembering the key's first word as
 * well as its tail word.
 *
 * @param a   raw address of the key's first byte
 * @param w0  first 8-byte word of the key
 * @param t   masked tail word of the key
 * @param kl  key length in bytes
 * @param h   hash computed for the key
 * @param val first sample
 */
Node(long a, long w0, long t, int kl, long h, long val) {
    // identity of the key
    addr = a;
    keylen = kl;
    hash = h;
    word0 = w0;
    tail = t;
    // a single sample is at once the minimum, the maximum and the running sum
    min = val;
    max = val;
    sum = val;
    count = 1;
}
final void add(short val) {
final void add(long val) {
sum += val;
count++;
if (val >= max) {
@ -148,17 +140,28 @@ public class CalculateAverage_abeobk {
}
}
final boolean contentEquals(long other_addr, long other_word0, long other_tail, int keylen) {
final boolean contentEquals(long other_addr, long other_word0, long other_tail, long kl) {
if (word0 != other_word0 || tail != other_tail)
return false;
// this is faster than a word-by-word comparison with early exit when the key is short
long xsum = 0;
int n = keylen & 0xF8;
for (int i = 8; i < n; i += 8) {
long n = kl & 0xF8;
for (long i = 8; i < n; i += 8) {
xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i));
}
return xsum == 0;
}
/**
 * Tells whether {@code other} holds the same key as this node.
 *
 * The keys match when they have the same length, the same masked tail word,
 * and identical 8-byte words over the whole-word prefix of the key.
 *
 * @param other node to compare against
 * @return {@code true} if both nodes refer to the same key
 */
final boolean contentEquals(Node other) {
    // The length has to be compared explicitly: a short key and a longer key
    // can share both the leading word(s) and the masked tail word, and the
    // loop below only walks this node's own whole-word prefix.
    if (keylen != other.keylen || tail != other.tail) {
        return false;
    }
    // keylen rounded down to a multiple of 8: the bytes covered by full words
    long n = keylen & 0xF8;
    for (long i = 0; i < n; i += 8) {
        if (UNSAFE.getLong(addr + i) != UNSAFE.getLong(other.addr + i)) {
            return false;
        }
    }
    return true;
}
}
// idea from royvanrijn
@ -168,24 +171,24 @@ public class CalculateAverage_abeobk {
}
// speed/collision balance
static final int xxh32(long hash) {
static final long xxh32(long hash) {
long h = hash * 37;
return (int) (h ^ (h >>> 29));
return (h ^ (h >>> 29));
}
// great idea from merykitty (Quan Anh Mai)
static final short parseNum(long num_word, int dot_pos) {
static final long parseNum(long num_word, int dot_pos) {
int shift = 28 - dot_pos;
long signed = (~num_word << 59) >> 63;
long dsmask = ~(signed & 0xFF);
long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
return (short) ((abs_val ^ signed) - signed);
return ((abs_val ^ signed) - signed);
}
// Thread pool worker
static final class Worker extends Thread {
final int thread_id;
final int thread_id; // for debug use only
Worker(int i) {
thread_id = i;
@ -195,16 +198,15 @@ public class CalculateAverage_abeobk {
@Override
public void run() {
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
int cnt = 0;
int id;
int cls = 0;
// process in small chunks to maintain disk locality (artsiomkorzun trick)
// but keep going instead of merging
while ((id = chunk_id.getAndIncrement()) < chunk_cnt) {
long addr = start_addr + id * CHUNK_SZ;
long end = Math.min(addr + CHUNK_SZ, end_addr);
// adjust start
// find start of line
if (id > 0) {
while (UNSAFE.getByte(addr++) != '\n')
;
@ -230,14 +232,14 @@ public class CalculateAverage_abeobk {
addr += (dot_pos >>> 3) + 3;
long tail = word0 & HASH_MASKS[semi_pos];
int bucket = xxh32(tail) & BUCKET_MASK;
short val = parseNum(num_word, dot_pos);
long hash = xxh32(tail);
int bucket = (int) (hash & BUCKET_MASK);
long val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
map[bucket] = new Node(row_addr, tail, val);
cnt++;
map[bucket] = new Node(row_addr, tail, semi_pos, hash, val);
break;
}
if (node.tail == tail) {
@ -263,14 +265,14 @@ public class CalculateAverage_abeobk {
addr += (dot_pos >>> 3) + 3;
long tail = (word & HASH_MASKS[semi_pos]);
int bucket = xxh32(word0 ^ tail) & BUCKET_MASK;
short val = parseNum(num_word, dot_pos);
long hash = xxh32(word0 ^ tail);
int bucket = (int) (hash & BUCKET_MASK);
long val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
map[bucket] = new Node(row_addr, word0, tail, val);
cnt++;
map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash, val);
break;
}
if (node.word0 == word0 && node.tail == tail) {
@ -295,20 +297,20 @@ public class CalculateAverage_abeobk {
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
addr += semi_pos;
int keylen = (int) (addr - row_addr);
long keylen = addr - row_addr;
long num_word = UNSAFE.getLong(addr + 1);
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 4;
long tail = (word & HASH_MASKS[semi_pos]);
int bucket = xxh32(hash ^ tail) & BUCKET_MASK;
short val = parseNum(num_word, dot_pos);
hash = xxh32(hash ^ tail);
int bucket = (int) (hash & BUCKET_MASK);
long val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
map[bucket] = new Node(row_addr, word0, tail, val);
cnt++;
map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash, val);
break;
}
if (node.contentEquals(row_addr, word0, tail, keylen)) {
@ -322,18 +324,36 @@ public class CalculateAverage_abeobk {
}
}
// merge is cheaper than string casting (artsiomkorzun)
while (!mapref.compareAndSet(null, map)) {
var other_map = mapref.getAndSet(null);
if (other_map != null) {
for (int i = 0; i < other_map.length; i++) {
var other = other_map[i];
if (other == null)
continue;
int bucket = (int) (other.hash & BUCKET_MASK);
while (true) {
var node = map[bucket];
if (node == null) {
map[bucket] = other;
break;
}
if (node.contentEquals(other)) {
node.merge(other);
break;
}
bucket++;
if (SHOW_ANALYSIS)
cls++;
}
}
}
}
if (SHOW_ANALYSIS) {
debug("Thread %d collision = %d", thread_id, cls);
}
Stat[] stats = new Stat[cnt];
int i = 0;
for (var node : map) {
if (node != null) {
stats[i++] = new Stat(node);
}
}
all_res[thread_id] = stats;
}
}
@ -366,23 +386,22 @@ public class CalculateAverage_abeobk {
// only use all cpus on large file
int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT;
chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ);
all_res = new Stat[cpu_cnt][];
List<Worker> workers = IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList();
for (var w : workers)
// spawn workers
for (var w : IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList()) {
w.join();
// collect all results
TreeMap<String, Stat> ms = new TreeMap<>();
for (var res : all_res) {
for (var s : res) {
var stat = ms.putIfAbsent(s.key, s);
if (stat != null)
stat.node.merge(s.node);
}
}
// print output
// collect results
TreeMap<String, Node> ms = new TreeMap<>();
for (var crr : mapref.get()) {
if (crr == null)
continue;
var prev = ms.putIfAbsent(crr.key(), crr);
if (prev != null)
prev.merge(crr);
}
// print result
System.out.println(ms);
System.out.close();
}