Use native type, remove lots of type conversions (#618)

* less type conversion, less string cast

* adjust some comments

* fixed format issue
This commit is contained in:
Van Phu DO 2024-01-29 02:08:42 +09:00 committed by GitHub
parent d5854d65e6
commit a33ed2181b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -26,9 +26,9 @@ import java.nio.file.Path;
import java.nio.file.StandardOpenOption; import java.nio.file.StandardOpenOption;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import sun.misc.Unsafe; import sun.misc.Unsafe;
@ -39,7 +39,7 @@ public class CalculateAverage_abeobk {
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
private static final int BUCKET_SIZE = 1 << 16; private static final int BUCKET_SIZE = 1 << 16;
private static final int BUCKET_MASK = BUCKET_SIZE - 1; private static final long BUCKET_MASK = BUCKET_SIZE - 1;
private static final int MAX_STR_LEN = 100; private static final int MAX_STR_LEN = 100;
private static final int MAX_STATIONS = 10000; private static final int MAX_STATIONS = 10000;
private static final long CHUNK_SZ = 1 << 22; // 4MB chunk private static final long CHUNK_SZ = 1 << 22; // 4MB chunk
@ -56,9 +56,9 @@ public class CalculateAverage_abeobk {
0xffffffffffffffffL, }; 0xffffffffffffffffL, };
private static AtomicInteger chunk_id = new AtomicInteger(0); private static AtomicInteger chunk_id = new AtomicInteger(0);
private static AtomicReference<Node[]> mapref = new AtomicReference<>(null);
private static int chunk_cnt; private static int chunk_cnt;
private static long start_addr, end_addr; private static long start_addr, end_addr;
private static Stat[][] all_res;
private static final void debug(String s, Object... args) { private static final void debug(String s, Object... args) {
System.out.println(String.format(s, args)); System.out.println(String.format(s, args));
@ -75,57 +75,49 @@ public class CalculateAverage_abeobk {
} }
} }
static class Stat { // use native type, less conversion
Node node;
String key;
public final String toString() {
return (node.min / 10.0) + "/"
+ (Math.round(((double) node.sum / node.count)) / 10.0) + "/"
+ (node.max / 10.0);
}
Stat(Node n) {
node = n;
byte[] sbuf = new byte[MAX_STR_LEN];
long word = UNSAFE.getLong(n.addr);
long semipos_code = getSemiPosCode(word);
int keylen = 0;
while (semipos_code == 0) {
keylen += 8;
word = UNSAFE.getLong(n.addr + keylen);
semipos_code = getSemiPosCode(word);
}
keylen += Long.numberOfTrailingZeros(semipos_code) >>> 3;
UNSAFE.copyMemory(null, n.addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
}
}
static class Node { static class Node {
long addr; long addr;
long hash;
long word0; long word0;
long tail; long tail;
long sum; long sum;
long min, max;
int keylen;
int count; int count;
short min, max;
Node(long a, long t, short val) { public final String toString() {
return (min / 10.0) + "/"
+ (Math.round(((double) sum / count)) / 10.0) + "/"
+ (max / 10.0);
}
final String key() {
byte[] sbuf = new byte[MAX_STR_LEN];
UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
return new String(sbuf, 0, (int) keylen, StandardCharsets.UTF_8);
}
Node(long a, long t, int kl, long h, long val) {
addr = a; addr = a;
tail = t; tail = t;
sum = min = max = val; sum = min = max = val;
count = 1; count = 1;
keylen = kl;
hash = h;
} }
Node(long a, long w0, long t, short val) { Node(long a, long w0, long t, int kl, long h, long val) {
addr = a; addr = a;
word0 = w0; word0 = w0;
tail = t; tail = t;
sum = min = max = val; sum = min = max = val;
count = 1; count = 1;
keylen = kl;
hash = h;
} }
final void add(short val) { final void add(long val) {
sum += val; sum += val;
count++; count++;
if (val >= max) { if (val >= max) {
@ -148,17 +140,28 @@ public class CalculateAverage_abeobk {
} }
} }
final boolean contentEquals(long other_addr, long other_word0, long other_tail, int keylen) { final boolean contentEquals(long other_addr, long other_word0, long other_tail, long kl) {
if (word0 != other_word0 || tail != other_tail) if (word0 != other_word0 || tail != other_tail)
return false; return false;
// this is faster than comparision if key is short // this is faster than comparision if key is short
long xsum = 0; long xsum = 0;
int n = keylen & 0xF8; long n = kl & 0xF8;
for (int i = 8; i < n; i += 8) { for (long i = 8; i < n; i += 8) {
xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i)); xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i));
} }
return xsum == 0; return xsum == 0;
} }
final boolean contentEquals(Node other) {
if (tail != other.tail)
return false;
long n = keylen & 0xF8;
for (long i = 0; i < n; i += 8) {
if (UNSAFE.getLong(addr + i) != UNSAFE.getLong(other.addr + i))
return false;
}
return true;
}
} }
// idea from royvanrijn // idea from royvanrijn
@ -168,24 +171,24 @@ public class CalculateAverage_abeobk {
} }
// speed/collision balance // speed/collision balance
static final int xxh32(long hash) { static final long xxh32(long hash) {
long h = hash * 37; long h = hash * 37;
return (int) (h ^ (h >>> 29)); return (h ^ (h >>> 29));
} }
// great idea from merykitty (Quan Anh Mai) // great idea from merykitty (Quan Anh Mai)
static final short parseNum(long num_word, int dot_pos) { static final long parseNum(long num_word, int dot_pos) {
int shift = 28 - dot_pos; int shift = 28 - dot_pos;
long signed = (~num_word << 59) >> 63; long signed = (~num_word << 59) >> 63;
long dsmask = ~(signed & 0xFF); long dsmask = ~(signed & 0xFF);
long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
return (short) ((abs_val ^ signed) - signed); return ((abs_val ^ signed) - signed);
} }
// Thread pool worker // Thread pool worker
static final class Worker extends Thread { static final class Worker extends Thread {
final int thread_id; final int thread_id; // for debug use only
Worker(int i) { Worker(int i) {
thread_id = i; thread_id = i;
@ -195,16 +198,15 @@ public class CalculateAverage_abeobk {
@Override @Override
public void run() { public void run() {
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
int cnt = 0;
int id; int id;
int cls = 0; int cls = 0;
// process in small chunk to maintain disk locality (artsiomkorzun trick) // process in small chunk to maintain disk locality (artsiomkorzun trick)
// but keep going instead of merging
while ((id = chunk_id.getAndIncrement()) < chunk_cnt) { while ((id = chunk_id.getAndIncrement()) < chunk_cnt) {
long addr = start_addr + id * CHUNK_SZ; long addr = start_addr + id * CHUNK_SZ;
long end = Math.min(addr + CHUNK_SZ, end_addr); long end = Math.min(addr + CHUNK_SZ, end_addr);
// adjust start
// find start of line
if (id > 0) { if (id > 0) {
while (UNSAFE.getByte(addr++) != '\n') while (UNSAFE.getByte(addr++) != '\n')
; ;
@ -230,14 +232,14 @@ public class CalculateAverage_abeobk {
addr += (dot_pos >>> 3) + 3; addr += (dot_pos >>> 3) + 3;
long tail = word0 & HASH_MASKS[semi_pos]; long tail = word0 & HASH_MASKS[semi_pos];
int bucket = xxh32(tail) & BUCKET_MASK; long hash = xxh32(tail);
short val = parseNum(num_word, dot_pos); int bucket = (int) (hash & BUCKET_MASK);
long val = parseNum(num_word, dot_pos);
while (true) { while (true) {
var node = map[bucket]; var node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, tail, val); map[bucket] = new Node(row_addr, tail, semi_pos, hash, val);
cnt++;
break; break;
} }
if (node.tail == tail) { if (node.tail == tail) {
@ -263,14 +265,14 @@ public class CalculateAverage_abeobk {
addr += (dot_pos >>> 3) + 3; addr += (dot_pos >>> 3) + 3;
long tail = (word & HASH_MASKS[semi_pos]); long tail = (word & HASH_MASKS[semi_pos]);
int bucket = xxh32(word0 ^ tail) & BUCKET_MASK; long hash = xxh32(word0 ^ tail);
short val = parseNum(num_word, dot_pos); int bucket = (int) (hash & BUCKET_MASK);
long val = parseNum(num_word, dot_pos);
while (true) { while (true) {
var node = map[bucket]; var node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, word0, tail, val); map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash, val);
cnt++;
break; break;
} }
if (node.word0 == word0 && node.tail == tail) { if (node.word0 == word0 && node.tail == tail) {
@ -295,20 +297,20 @@ public class CalculateAverage_abeobk {
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
addr += semi_pos; addr += semi_pos;
int keylen = (int) (addr - row_addr); long keylen = addr - row_addr;
long num_word = UNSAFE.getLong(addr + 1); long num_word = UNSAFE.getLong(addr + 1);
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 4; addr += (dot_pos >>> 3) + 4;
long tail = (word & HASH_MASKS[semi_pos]); long tail = (word & HASH_MASKS[semi_pos]);
int bucket = xxh32(hash ^ tail) & BUCKET_MASK; hash = xxh32(hash ^ tail);
short val = parseNum(num_word, dot_pos); int bucket = (int) (hash & BUCKET_MASK);
long val = parseNum(num_word, dot_pos);
while (true) { while (true) {
var node = map[bucket]; var node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, word0, tail, val); map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash, val);
cnt++;
break; break;
} }
if (node.contentEquals(row_addr, word0, tail, keylen)) { if (node.contentEquals(row_addr, word0, tail, keylen)) {
@ -322,18 +324,36 @@ public class CalculateAverage_abeobk {
} }
} }
// merge is cheaper than string casting (artsiomkorzun)
while (!mapref.compareAndSet(null, map)) {
var other_map = mapref.getAndSet(null);
if (other_map != null) {
for (int i = 0; i < other_map.length; i++) {
var other = other_map[i];
if (other == null)
continue;
int bucket = (int) (other.hash & BUCKET_MASK);
while (true) {
var node = map[bucket];
if (node == null) {
map[bucket] = other;
break;
}
if (node.contentEquals(other)) {
node.merge(other);
break;
}
bucket++;
if (SHOW_ANALYSIS)
cls++;
}
}
}
}
if (SHOW_ANALYSIS) { if (SHOW_ANALYSIS) {
debug("Thread %d collision = %d", thread_id, cls); debug("Thread %d collision = %d", thread_id, cls);
} }
Stat[] stats = new Stat[cnt];
int i = 0;
for (var node : map) {
if (node != null) {
stats[i++] = new Stat(node);
}
}
all_res[thread_id] = stats;
} }
} }
@ -366,23 +386,22 @@ public class CalculateAverage_abeobk {
// only use all cpus on large file // only use all cpus on large file
int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT; int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT;
chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ); chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ);
all_res = new Stat[cpu_cnt][];
List<Worker> workers = IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList(); // spawn workers
for (var w : workers) for (var w : IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList()) {
w.join(); w.join();
// collect all results
TreeMap<String, Stat> ms = new TreeMap<>();
for (var res : all_res) {
for (var s : res) {
var stat = ms.putIfAbsent(s.key, s);
if (stat != null)
stat.node.merge(s.node);
}
} }
// print output // collect results
TreeMap<String, Node> ms = new TreeMap<>();
for (var crr : mapref.get()) {
if (crr == null)
continue;
var prev = ms.putIfAbsent(crr.key(), crr);
if (prev != null)
prev.merge(crr);
}
// print result
System.out.println(ms); System.out.println(ms);
System.out.close(); System.out.close();
} }