Simplify Node class with less field, improve hash mix speed (#584)
* Simplify Node class with less field, improve hash mix speed * remove some ops, a bit faster * more inline, little bit faster but not sure
This commit is contained in:
parent
ce9455a584
commit
271bdfb032
@ -20,6 +20,8 @@ sdk use java 21.0.2-graal 1>&2
|
||||
|
||||
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
|
||||
if [ ! -f target/CalculateAverage_abeobk_image ]; then
|
||||
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
|
||||
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -dsa -march=native -R:MaxHeapSize=128m -H:-GenLoopSafepoints -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
|
||||
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk
|
||||
fi
|
||||
|
||||
|
||||
|
@ -39,6 +39,7 @@ public class CalculateAverage_abeobk {
|
||||
private static final int BUCKET_SIZE = 1 << 16;
|
||||
private static final int BUCKET_MASK = BUCKET_SIZE - 1;
|
||||
private static final int MAX_STR_LEN = 100;
|
||||
private static final int MAX_STATIONS = 10000;
|
||||
private static final Unsafe UNSAFE = initUnsafe();
|
||||
private static final long[] HASH_MASKS = new long[]{
|
||||
0x0L,
|
||||
@ -66,6 +67,33 @@ public class CalculateAverage_abeobk {
|
||||
}
|
||||
}
|
||||
|
||||
static class Stat {
|
||||
Node node;
|
||||
String key;
|
||||
|
||||
public final String toString() {
|
||||
return (node.min / 10.0) + "/"
|
||||
+ (Math.round(((double) node.sum / node.count)) / 10.0) + "/"
|
||||
+ (node.max / 10.0);
|
||||
}
|
||||
|
||||
Stat(Node n) {
|
||||
node = n;
|
||||
byte[] sbuf = new byte[MAX_STR_LEN];
|
||||
long word = UNSAFE.getLong(n.addr);
|
||||
long semipos_code = getSemiPosCode(word);
|
||||
int keylen = 0;
|
||||
while (semipos_code == 0) {
|
||||
keylen += 8;
|
||||
word = UNSAFE.getLong(n.addr + keylen);
|
||||
semipos_code = getSemiPosCode(word);
|
||||
}
|
||||
keylen += Long.numberOfTrailingZeros(semipos_code) >>> 3;
|
||||
UNSAFE.copyMemory(null, n.addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
|
||||
key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
|
||||
}
|
||||
}
|
||||
|
||||
static class Node {
|
||||
long addr;
|
||||
long word0;
|
||||
@ -73,37 +101,23 @@ public class CalculateAverage_abeobk {
|
||||
long sum;
|
||||
int count;
|
||||
short min, max;
|
||||
int keylen;
|
||||
String key;
|
||||
|
||||
void calcKey() {
|
||||
byte[] sbuf = new byte[MAX_STR_LEN];
|
||||
UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
|
||||
key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1);
|
||||
}
|
||||
|
||||
Node(long a, long t, short val, int kl) {
|
||||
Node(long a, long t, short val) {
|
||||
addr = a;
|
||||
tail = t;
|
||||
keylen = kl;
|
||||
sum = min = max = val;
|
||||
count = 1;
|
||||
}
|
||||
|
||||
Node(long a, long w0, long t, short val, int kl) {
|
||||
Node(long a, long w0, long t, short val) {
|
||||
addr = a;
|
||||
word0 = w0;
|
||||
tail = t;
|
||||
keylen = kl;
|
||||
sum = min = max = val;
|
||||
count = 1;
|
||||
}
|
||||
|
||||
void add(short val) {
|
||||
final void add(short val) {
|
||||
sum += val;
|
||||
count++;
|
||||
if (val >= max) {
|
||||
@ -115,7 +129,7 @@ public class CalculateAverage_abeobk {
|
||||
}
|
||||
}
|
||||
|
||||
void merge(Node other) {
|
||||
final void merge(Node other) {
|
||||
sum += other.sum;
|
||||
count += other.count;
|
||||
if (other.max > max) {
|
||||
@ -126,8 +140,8 @@ public class CalculateAverage_abeobk {
|
||||
}
|
||||
}
|
||||
|
||||
boolean contentEquals(long other_addr, long other_word0, long other_tail) {
|
||||
if (tail != other_tail || word0 != other_word0)
|
||||
final boolean contentEquals(long other_addr, long other_word0, long other_tail, int keylen) {
|
||||
if (word0 != other_word0 || tail != other_tail)
|
||||
return false;
|
||||
// this is faster than comparision if key is short
|
||||
long xsum = 0;
|
||||
@ -161,11 +175,8 @@ public class CalculateAverage_abeobk {
|
||||
|
||||
// speed/collision balance
|
||||
static final int xxh32(long hash) {
|
||||
final int p1 = 0x85EBCA77; // prime
|
||||
int low = (int) hash;
|
||||
int high = (int) (hash >>> 33);
|
||||
int h = (low * p1) ^ high;
|
||||
return h ^ (h >>> 17);
|
||||
long h = hash * 37;
|
||||
return (int) (h ^ (h >>> 29));
|
||||
}
|
||||
|
||||
// great idea from merykitty (Quan Anh Mai)
|
||||
@ -185,11 +196,10 @@ public class CalculateAverage_abeobk {
|
||||
static final Node[] parse(int thread_id, long start, long end) {
|
||||
int cls = 0;
|
||||
long addr = start;
|
||||
var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
|
||||
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
|
||||
// parse loop
|
||||
while (addr < end) {
|
||||
long row_addr = addr;
|
||||
long hash = 0;
|
||||
|
||||
long word0 = UNSAFE.getLong(addr);
|
||||
long semipos_code = getSemiPosCode(word0);
|
||||
@ -202,14 +212,14 @@ public class CalculateAverage_abeobk {
|
||||
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
||||
addr += (dot_pos >>> 3) + 3;
|
||||
|
||||
long tail = (word0 & HASH_MASKS[semi_pos]);
|
||||
long tail = word0 & HASH_MASKS[semi_pos];
|
||||
int bucket = xxh32(tail) & BUCKET_MASK;
|
||||
short val = parseNum(num_word, dot_pos);
|
||||
|
||||
while (true) {
|
||||
var node = map[bucket];
|
||||
if (node == null) {
|
||||
map[bucket] = new Node(row_addr, tail, val, semi_pos);
|
||||
map[bucket] = new Node(row_addr, tail, val);
|
||||
break;
|
||||
}
|
||||
if (node.tail == tail) {
|
||||
@ -223,28 +233,25 @@ public class CalculateAverage_abeobk {
|
||||
continue;
|
||||
}
|
||||
|
||||
hash ^= word0;
|
||||
addr += 8;
|
||||
long word = UNSAFE.getLong(addr);
|
||||
semipos_code = getSemiPosCode(word);
|
||||
// 43% chance
|
||||
if (semipos_code != 0) {
|
||||
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
|
||||
addr += semi_pos;
|
||||
int keylen = (int) (addr - row_addr);
|
||||
long num_word = UNSAFE.getLong(addr + 1);
|
||||
addr += semi_pos + 1;
|
||||
long num_word = UNSAFE.getLong(addr);
|
||||
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
||||
addr += (dot_pos >>> 3) + 4;
|
||||
addr += (dot_pos >>> 3) + 3;
|
||||
|
||||
long tail = (word & HASH_MASKS[semi_pos]);
|
||||
hash ^= tail;
|
||||
int bucket = xxh32(hash) & BUCKET_MASK;
|
||||
int bucket = xxh32(word0 ^ tail) & BUCKET_MASK;
|
||||
short val = parseNum(num_word, dot_pos);
|
||||
|
||||
while (true) {
|
||||
var node = map[bucket];
|
||||
if (node == null) {
|
||||
map[bucket] = new Node(row_addr, word0, tail, val, keylen);
|
||||
map[bucket] = new Node(row_addr, word0, tail, val);
|
||||
break;
|
||||
}
|
||||
if (node.word0 == word0 && node.tail == tail) {
|
||||
@ -258,6 +265,9 @@ public class CalculateAverage_abeobk {
|
||||
continue;
|
||||
}
|
||||
|
||||
// why not going for more? tested, slower
|
||||
|
||||
long hash = word0;
|
||||
while (semipos_code == 0) {
|
||||
hash ^= word;
|
||||
addr += 8;
|
||||
@ -273,17 +283,16 @@ public class CalculateAverage_abeobk {
|
||||
addr += (dot_pos >>> 3) + 4;
|
||||
|
||||
long tail = (word & HASH_MASKS[semi_pos]);
|
||||
hash ^= tail;
|
||||
int bucket = xxh32(hash) & BUCKET_MASK;
|
||||
int bucket = xxh32(hash ^ tail) & BUCKET_MASK;
|
||||
short val = parseNum(num_word, dot_pos);
|
||||
|
||||
while (true) {
|
||||
var node = map[bucket];
|
||||
if (node == null) {
|
||||
map[bucket] = new Node(row_addr, word0, tail, val, keylen);
|
||||
map[bucket] = new Node(row_addr, word0, tail, val);
|
||||
break;
|
||||
}
|
||||
if (node.contentEquals(row_addr, word0, tail)) {
|
||||
if (node.contentEquals(row_addr, word0, tail, keylen)) {
|
||||
node.add(val);
|
||||
break;
|
||||
}
|
||||
@ -292,6 +301,7 @@ public class CalculateAverage_abeobk {
|
||||
cls++;
|
||||
}
|
||||
}
|
||||
|
||||
if (SHOW_ANALYSIS) {
|
||||
debug("Thread %d collision = %d", thread_id, cls);
|
||||
}
|
||||
@ -307,8 +317,6 @@ public class CalculateAverage_abeobk {
|
||||
workerCommand.add("--worker");
|
||||
new ProcessBuilder()
|
||||
.command(workerCommand)
|
||||
.inheritIO()
|
||||
.redirectOutput(ProcessBuilder.Redirect.PIPE)
|
||||
.start()
|
||||
.getInputStream()
|
||||
.transferTo(System.out);
|
||||
@ -333,43 +341,29 @@ public class CalculateAverage_abeobk {
|
||||
// processing
|
||||
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
|
||||
|
||||
TreeMap<String, Node> ms = new TreeMap<>();
|
||||
int[] lenhist = new int[64]; // length histogram
|
||||
|
||||
List<List<Node>> maps = IntStream.range(0, cpu_cnt)
|
||||
List<List<Stat>> maps = IntStream.range(0, cpu_cnt)
|
||||
.mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1]))
|
||||
.map(map -> {
|
||||
List<Node> nodes = new ArrayList<>();
|
||||
List<Stat> stats = new ArrayList<>();
|
||||
for (var node : map) {
|
||||
if (node == null)
|
||||
continue;
|
||||
node.calcKey();
|
||||
nodes.add(node);
|
||||
stats.add(new Stat(node));
|
||||
}
|
||||
return nodes;
|
||||
return stats;
|
||||
})
|
||||
.parallel()
|
||||
.toList();
|
||||
|
||||
for (var nodes : maps) {
|
||||
for (var node : nodes) {
|
||||
if (SHOW_ANALYSIS) {
|
||||
int kl = node.keylen & (lenhist.length - 1);
|
||||
lenhist[kl] += node.count;
|
||||
}
|
||||
var stat = ms.putIfAbsent(node.key, node);
|
||||
TreeMap<String, Stat> ms = new TreeMap<>();
|
||||
for (var stats : maps) {
|
||||
for (var s : stats) {
|
||||
var stat = ms.putIfAbsent(s.key, s);
|
||||
if (stat != null)
|
||||
stat.merge(node);
|
||||
stat.node.merge(s.node);
|
||||
}
|
||||
}
|
||||
|
||||
if (SHOW_ANALYSIS) {
|
||||
debug("Total = " + Arrays.stream(lenhist).sum());
|
||||
debug("Length_histogram = "
|
||||
+ Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray()));
|
||||
return;
|
||||
}
|
||||
|
||||
// print result
|
||||
System.out.println(ms);
|
||||
System.out.close();
|
||||
|
Loading…
Reference in New Issue
Block a user