Simplify Node class with fewer fields, improve hash mix speed (#584)

* Simplify Node class with fewer fields, improve hash mix speed

* Remove some ops; a bit faster

* More inlining; a little bit faster, but not sure
Van Phu DO, 2024-01-26 06:57:04 +09:00, committed by GitHub
parent ce9455a584
commit 271bdfb032
2 changed files with 63 additions and 67 deletions
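
The "improve hash mix speed" item above refers to the xxh32 mixer swap in the Java diff below. A minimal sketch with both bodies copied from that hunk; the Old/New method names and the sample input are illustrative only:

// Sketch comparing the hash mixers before and after this commit.
// Both bodies come from the xxh32 hunk below; names are illustrative.
public class HashMixSketch {
    // Pre-commit mixer: 32-bit prime multiply, xor with the high half, then a shift-xor.
    static int xxh32Old(long hash) {
        final int p1 = 0x85EBCA77; // prime
        int low = (int) hash;
        int high = (int) (hash >>> 33);
        int h = (low * p1) ^ high;
        return h ^ (h >>> 17);
    }

    // Post-commit mixer: one 64-bit multiply by 37 and a single shift-xor.
    static int xxh32New(long hash) {
        long h = hash * 37;
        return (int) (h ^ (h >>> 29));
    }

    public static void main(String[] args) {
        long sample = 0x6F6E614800000000L; // arbitrary stand-in for a masked key tail
        System.out.printf("old=%08x new=%08x%n", xxh32Old(sample), xxh32New(sample));
    }
}

The new mixer replaces the 32-bit split, the prime multiply, and the extra xor with a single 64-bit multiply and one shift-xor, i.e. fewer operations per hashed key.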

File 1 of 2: native-image build script

@@ -20,6 +20,8 @@ sdk use java 21.0.2-graal 1>&2
 # ./mvnw clean verify removes target/ and will re-trigger native image creation.
 if [ ! -f target/CalculateAverage_abeobk_image ]; then
-    NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
+    NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -dsa -march=native -R:MaxHeapSize=128m -H:-GenLoopSafepoints -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
     native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk
 fi

File 2 of 2: CalculateAverage_abeobk.java

@@ -39,6 +39,7 @@ public class CalculateAverage_abeobk {
     private static final int BUCKET_SIZE = 1 << 16;
     private static final int BUCKET_MASK = BUCKET_SIZE - 1;
     private static final int MAX_STR_LEN = 100;
+    private static final int MAX_STATIONS = 10000;
     private static final Unsafe UNSAFE = initUnsafe();
     private static final long[] HASH_MASKS = new long[]{
             0x0L,
@@ -66,6 +67,33 @@ public class CalculateAverage_abeobk {
         }
     }
+    static class Stat {
+        Node node;
+        String key;
+        public final String toString() {
+            return (node.min / 10.0) + "/"
+                    + (Math.round(((double) node.sum / node.count)) / 10.0) + "/"
+                    + (node.max / 10.0);
+        }
+        Stat(Node n) {
+            node = n;
+            byte[] sbuf = new byte[MAX_STR_LEN];
+            long word = UNSAFE.getLong(n.addr);
+            long semipos_code = getSemiPosCode(word);
+            int keylen = 0;
+            while (semipos_code == 0) {
+                keylen += 8;
+                word = UNSAFE.getLong(n.addr + keylen);
+                semipos_code = getSemiPosCode(word);
+            }
+            keylen += Long.numberOfTrailingZeros(semipos_code) >>> 3;
+            UNSAFE.copyMemory(null, n.addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
+            key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
+        }
+    }
     static class Node {
         long addr;
         long word0;
@@ -73,37 +101,23 @@ public class CalculateAverage_abeobk {
         long sum;
         int count;
         short min, max;
-        int keylen;
-        String key;
-        void calcKey() {
-            byte[] sbuf = new byte[MAX_STR_LEN];
-            UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
-            key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
-        }
-        public String toString() {
-            return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1);
-        }
-        Node(long a, long t, short val, int kl) {
+        Node(long a, long t, short val) {
             addr = a;
             tail = t;
-            keylen = kl;
             sum = min = max = val;
             count = 1;
         }
-        Node(long a, long w0, long t, short val, int kl) {
+        Node(long a, long w0, long t, short val) {
             addr = a;
             word0 = w0;
             tail = t;
-            keylen = kl;
             sum = min = max = val;
             count = 1;
         }
-        void add(short val) {
+        final void add(short val) {
             sum += val;
             count++;
             if (val >= max) {
@@ -115,7 +129,7 @@ public class CalculateAverage_abeobk {
             }
         }
-        void merge(Node other) {
+        final void merge(Node other) {
             sum += other.sum;
             count += other.count;
             if (other.max > max) {
@@ -126,8 +140,8 @@ public class CalculateAverage_abeobk {
             }
         }
-        boolean contentEquals(long other_addr, long other_word0, long other_tail) {
-            if (tail != other_tail || word0 != other_word0)
+        final boolean contentEquals(long other_addr, long other_word0, long other_tail, int keylen) {
+            if (word0 != other_word0 || tail != other_tail)
                 return false;
             // this is faster than comparision if key is short
             long xsum = 0;
@@ -161,11 +175,8 @@
     // speed/collision balance
     static final int xxh32(long hash) {
-        final int p1 = 0x85EBCA77; // prime
-        int low = (int) hash;
-        int high = (int) (hash >>> 33);
-        int h = (low * p1) ^ high;
-        return h ^ (h >>> 17);
+        long h = hash * 37;
+        return (int) (h ^ (h >>> 29));
     }
     // great idea from merykitty (Quan Anh Mai)
@@ -185,11 +196,10 @@
     static final Node[] parse(int thread_id, long start, long end) {
         int cls = 0;
         long addr = start;
-        var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
+        var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
         // parse loop
         while (addr < end) {
             long row_addr = addr;
-            long hash = 0;
             long word0 = UNSAFE.getLong(addr);
             long semipos_code = getSemiPosCode(word0);
@@ -202,14 +212,14 @@
                 int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
                 addr += (dot_pos >>> 3) + 3;
-                long tail = (word0 & HASH_MASKS[semi_pos]);
+                long tail = word0 & HASH_MASKS[semi_pos];
                 int bucket = xxh32(tail) & BUCKET_MASK;
                 short val = parseNum(num_word, dot_pos);
                 while (true) {
                     var node = map[bucket];
                     if (node == null) {
-                        map[bucket] = new Node(row_addr, tail, val, semi_pos);
+                        map[bucket] = new Node(row_addr, tail, val);
                         break;
                     }
                     if (node.tail == tail) {
@@ -223,28 +233,25 @@
                 continue;
             }
-            hash ^= word0;
             addr += 8;
             long word = UNSAFE.getLong(addr);
             semipos_code = getSemiPosCode(word);
             // 43% chance
             if (semipos_code != 0) {
                 int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
-                addr += semi_pos;
-                int keylen = (int) (addr - row_addr);
-                long num_word = UNSAFE.getLong(addr + 1);
+                addr += semi_pos + 1;
+                long num_word = UNSAFE.getLong(addr);
                 int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
-                addr += (dot_pos >>> 3) + 4;
+                addr += (dot_pos >>> 3) + 3;
                 long tail = (word & HASH_MASKS[semi_pos]);
-                hash ^= tail;
-                int bucket = xxh32(hash) & BUCKET_MASK;
+                int bucket = xxh32(word0 ^ tail) & BUCKET_MASK;
                 short val = parseNum(num_word, dot_pos);
                 while (true) {
                     var node = map[bucket];
                     if (node == null) {
-                        map[bucket] = new Node(row_addr, word0, tail, val, keylen);
+                        map[bucket] = new Node(row_addr, word0, tail, val);
                         break;
                     }
                     if (node.word0 == word0 && node.tail == tail) {
@@ -258,6 +265,9 @@ public class CalculateAverage_abeobk {
                 continue;
             }
+            // why not going for more? tested, slower
+            long hash = word0;
             while (semipos_code == 0) {
                 hash ^= word;
                 addr += 8;
@@ -273,17 +283,16 @@
             addr += (dot_pos >>> 3) + 4;
             long tail = (word & HASH_MASKS[semi_pos]);
-            hash ^= tail;
-            int bucket = xxh32(hash) & BUCKET_MASK;
+            int bucket = xxh32(hash ^ tail) & BUCKET_MASK;
             short val = parseNum(num_word, dot_pos);
             while (true) {
                 var node = map[bucket];
                 if (node == null) {
-                    map[bucket] = new Node(row_addr, word0, tail, val, keylen);
+                    map[bucket] = new Node(row_addr, word0, tail, val);
                     break;
                 }
-                if (node.contentEquals(row_addr, word0, tail)) {
+                if (node.contentEquals(row_addr, word0, tail, keylen)) {
                     node.add(val);
                     break;
                 }
@@ -292,6 +301,7 @@ public class CalculateAverage_abeobk {
                 cls++;
             }
         }
         if (SHOW_ANALYSIS) {
             debug("Thread %d collision = %d", thread_id, cls);
         }
@@ -307,8 +317,6 @@ public class CalculateAverage_abeobk {
         workerCommand.add("--worker");
         new ProcessBuilder()
                 .command(workerCommand)
-                .inheritIO()
-                .redirectOutput(ProcessBuilder.Redirect.PIPE)
                 .start()
                 .getInputStream()
                 .transferTo(System.out);
@@ -333,43 +341,29 @@
         // processing
         var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
-        TreeMap<String, Node> ms = new TreeMap<>();
-        int[] lenhist = new int[64]; // length histogram
-        List<List<Node>> maps = IntStream.range(0, cpu_cnt)
+        List<List<Stat>> maps = IntStream.range(0, cpu_cnt)
                 .mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1]))
                 .map(map -> {
-                    List<Node> nodes = new ArrayList<>();
+                    List<Stat> stats = new ArrayList<>();
                     for (var node : map) {
                         if (node == null)
                             continue;
-                        node.calcKey();
-                        nodes.add(node);
+                        stats.add(new Stat(node));
                     }
-                    return nodes;
+                    return stats;
                 })
                 .parallel()
                 .toList();
-        for (var nodes : maps) {
-            for (var node : nodes) {
-                if (SHOW_ANALYSIS) {
-                    int kl = node.keylen & (lenhist.length - 1);
-                    lenhist[kl] += node.count;
-                }
-                var stat = ms.putIfAbsent(node.key, node);
+        TreeMap<String, Stat> ms = new TreeMap<>();
+        for (var stats : maps) {
+            for (var s : stats) {
+                var stat = ms.putIfAbsent(s.key, s);
                 if (stat != null)
-                    stat.merge(node);
+                    stat.node.merge(s.node);
             }
         }
-        if (SHOW_ANALYSIS) {
-            debug("Total = " + Arrays.stream(lenhist).sum());
-            debug("Length_histogram = "
-                    + Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray()));
-            return;
-        }
         // print result
         System.out.println(ms);
         System.out.close();
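
Both the new Stat constructor and the parse loop above call getSemiPosCode, which is defined outside the hunks shown in this commit. A minimal sketch of how such a helper is typically written, assuming the standard SWAR zero-byte trick rather than the author's exact code:

// Assumed sketch of a SWAR ';'-finder in the spirit of getSemiPosCode.
// Not part of this commit; the real helper lives outside the shown hunks.
public class SemiPosSketch {
    // Returns a word whose lowest set bit marks the first ';' byte (0 if none).
    static long getSemiPosCode(long word) {
        long x = word ^ 0x3B3B3B3B3B3B3B3BL; // ';' bytes become 0x00
        return (x - 0x0101010101010101L) & ~x & 0x8080808080808080L;
    }

    public static void main(String[] args) {
        // "Hanoi;12" packed little-endian: the ';' sits at byte index 5.
        long word = 0x32313B696F6E6148L;
        int semiPos = Long.numberOfTrailingZeros(getSemiPosCode(word)) >>> 3;
        System.out.println(semiPos); // prints 5
    }
}

The lowest set bit of the returned word marks the first ';' byte, which is why the callers take Long.numberOfTrailingZeros(semipos_code) >>> 3 to turn it into a byte offset.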