Simplify Node class with less field, improve hash mix speed (#584)

* Simplify Node class with less field, improve hash mix speed

* remove some ops, a bit faster

* more inline, little bit faster but not sure
This commit is contained in:
Van Phu DO 2024-01-26 06:57:04 +09:00 committed by GitHub
parent ce9455a584
commit 271bdfb032
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 67 deletions

View File

@ -20,6 +20,8 @@ sdk use java 21.0.2-graal 1>&2
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
if [ ! -f target/CalculateAverage_abeobk_image ]; then
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -dsa -march=native -R:MaxHeapSize=128m -H:-GenLoopSafepoints -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk
fi

View File

@ -39,6 +39,7 @@ public class CalculateAverage_abeobk {
private static final int BUCKET_SIZE = 1 << 16;
private static final int BUCKET_MASK = BUCKET_SIZE - 1;
private static final int MAX_STR_LEN = 100;
private static final int MAX_STATIONS = 10000;
private static final Unsafe UNSAFE = initUnsafe();
private static final long[] HASH_MASKS = new long[]{
0x0L,
@ -66,6 +67,33 @@ public class CalculateAverage_abeobk {
}
}
static class Stat {
Node node;
String key;
public final String toString() {
return (node.min / 10.0) + "/"
+ (Math.round(((double) node.sum / node.count)) / 10.0) + "/"
+ (node.max / 10.0);
}
Stat(Node n) {
node = n;
byte[] sbuf = new byte[MAX_STR_LEN];
long word = UNSAFE.getLong(n.addr);
long semipos_code = getSemiPosCode(word);
int keylen = 0;
while (semipos_code == 0) {
keylen += 8;
word = UNSAFE.getLong(n.addr + keylen);
semipos_code = getSemiPosCode(word);
}
keylen += Long.numberOfTrailingZeros(semipos_code) >>> 3;
UNSAFE.copyMemory(null, n.addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
}
}
static class Node {
long addr;
long word0;
@ -73,37 +101,23 @@ public class CalculateAverage_abeobk {
long sum;
int count;
short min, max;
int keylen;
String key;
void calcKey() {
byte[] sbuf = new byte[MAX_STR_LEN];
UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
}
public String toString() {
return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1);
}
Node(long a, long t, short val, int kl) {
Node(long a, long t, short val) {
addr = a;
tail = t;
keylen = kl;
sum = min = max = val;
count = 1;
}
Node(long a, long w0, long t, short val, int kl) {
Node(long a, long w0, long t, short val) {
addr = a;
word0 = w0;
tail = t;
keylen = kl;
sum = min = max = val;
count = 1;
}
void add(short val) {
final void add(short val) {
sum += val;
count++;
if (val >= max) {
@ -115,7 +129,7 @@ public class CalculateAverage_abeobk {
}
}
void merge(Node other) {
final void merge(Node other) {
sum += other.sum;
count += other.count;
if (other.max > max) {
@ -126,8 +140,8 @@ public class CalculateAverage_abeobk {
}
}
boolean contentEquals(long other_addr, long other_word0, long other_tail) {
if (tail != other_tail || word0 != other_word0)
final boolean contentEquals(long other_addr, long other_word0, long other_tail, int keylen) {
if (word0 != other_word0 || tail != other_tail)
return false;
// this is faster than comparision if key is short
long xsum = 0;
@ -161,11 +175,8 @@ public class CalculateAverage_abeobk {
// speed/collision balance
static final int xxh32(long hash) {
final int p1 = 0x85EBCA77; // prime
int low = (int) hash;
int high = (int) (hash >>> 33);
int h = (low * p1) ^ high;
return h ^ (h >>> 17);
long h = hash * 37;
return (int) (h ^ (h >>> 29));
}
// great idea from merykitty (Quan Anh Mai)
@ -185,11 +196,10 @@ public class CalculateAverage_abeobk {
static final Node[] parse(int thread_id, long start, long end) {
int cls = 0;
long addr = start;
var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
// parse loop
while (addr < end) {
long row_addr = addr;
long hash = 0;
long word0 = UNSAFE.getLong(addr);
long semipos_code = getSemiPosCode(word0);
@ -202,14 +212,14 @@ public class CalculateAverage_abeobk {
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 3;
long tail = (word0 & HASH_MASKS[semi_pos]);
long tail = word0 & HASH_MASKS[semi_pos];
int bucket = xxh32(tail) & BUCKET_MASK;
short val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
map[bucket] = new Node(row_addr, tail, val, semi_pos);
map[bucket] = new Node(row_addr, tail, val);
break;
}
if (node.tail == tail) {
@ -223,28 +233,25 @@ public class CalculateAverage_abeobk {
continue;
}
hash ^= word0;
addr += 8;
long word = UNSAFE.getLong(addr);
semipos_code = getSemiPosCode(word);
// 43% chance
if (semipos_code != 0) {
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
addr += semi_pos;
int keylen = (int) (addr - row_addr);
long num_word = UNSAFE.getLong(addr + 1);
addr += semi_pos + 1;
long num_word = UNSAFE.getLong(addr);
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 4;
addr += (dot_pos >>> 3) + 3;
long tail = (word & HASH_MASKS[semi_pos]);
hash ^= tail;
int bucket = xxh32(hash) & BUCKET_MASK;
int bucket = xxh32(word0 ^ tail) & BUCKET_MASK;
short val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
map[bucket] = new Node(row_addr, word0, tail, val, keylen);
map[bucket] = new Node(row_addr, word0, tail, val);
break;
}
if (node.word0 == word0 && node.tail == tail) {
@ -258,6 +265,9 @@ public class CalculateAverage_abeobk {
continue;
}
// why not going for more? tested, slower
long hash = word0;
while (semipos_code == 0) {
hash ^= word;
addr += 8;
@ -273,17 +283,16 @@ public class CalculateAverage_abeobk {
addr += (dot_pos >>> 3) + 4;
long tail = (word & HASH_MASKS[semi_pos]);
hash ^= tail;
int bucket = xxh32(hash) & BUCKET_MASK;
int bucket = xxh32(hash ^ tail) & BUCKET_MASK;
short val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
map[bucket] = new Node(row_addr, word0, tail, val, keylen);
map[bucket] = new Node(row_addr, word0, tail, val);
break;
}
if (node.contentEquals(row_addr, word0, tail)) {
if (node.contentEquals(row_addr, word0, tail, keylen)) {
node.add(val);
break;
}
@ -292,6 +301,7 @@ public class CalculateAverage_abeobk {
cls++;
}
}
if (SHOW_ANALYSIS) {
debug("Thread %d collision = %d", thread_id, cls);
}
@ -307,8 +317,6 @@ public class CalculateAverage_abeobk {
workerCommand.add("--worker");
new ProcessBuilder()
.command(workerCommand)
.inheritIO()
.redirectOutput(ProcessBuilder.Redirect.PIPE)
.start()
.getInputStream()
.transferTo(System.out);
@ -333,43 +341,29 @@ public class CalculateAverage_abeobk {
// processing
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
TreeMap<String, Node> ms = new TreeMap<>();
int[] lenhist = new int[64]; // length histogram
List<List<Node>> maps = IntStream.range(0, cpu_cnt)
List<List<Stat>> maps = IntStream.range(0, cpu_cnt)
.mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1]))
.map(map -> {
List<Node> nodes = new ArrayList<>();
List<Stat> stats = new ArrayList<>();
for (var node : map) {
if (node == null)
continue;
node.calcKey();
nodes.add(node);
stats.add(new Stat(node));
}
return nodes;
return stats;
})
.parallel()
.toList();
for (var nodes : maps) {
for (var node : nodes) {
if (SHOW_ANALYSIS) {
int kl = node.keylen & (lenhist.length - 1);
lenhist[kl] += node.count;
}
var stat = ms.putIfAbsent(node.key, node);
TreeMap<String, Stat> ms = new TreeMap<>();
for (var stats : maps) {
for (var s : stats) {
var stat = ms.putIfAbsent(s.key, s);
if (stat != null)
stat.merge(node);
stat.node.merge(s.node);
}
}
if (SHOW_ANALYSIS) {
debug("Total = " + Arrays.stream(lenhist).sum());
debug("Length_histogram = "
+ Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray()));
return;
}
// print result
System.out.println(ms);
System.out.close();