Simplify Node class with less field, improve hash mix speed (#584)
* Simplify Node class with less field, improve hash mix speed * remove some ops, a bit faster * more inline, little bit faster but not sure
This commit is contained in:
parent
ce9455a584
commit
271bdfb032
@ -20,6 +20,8 @@ sdk use java 21.0.2-graal 1>&2
|
|||||||
|
|
||||||
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
|
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
|
||||||
if [ ! -f target/CalculateAverage_abeobk_image ]; then
|
if [ ! -f target/CalculateAverage_abeobk_image ]; then
|
||||||
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
|
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -dsa -march=native -R:MaxHeapSize=128m -H:-GenLoopSafepoints -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
|
||||||
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk
|
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,6 +39,7 @@ public class CalculateAverage_abeobk {
|
|||||||
private static final int BUCKET_SIZE = 1 << 16;
|
private static final int BUCKET_SIZE = 1 << 16;
|
||||||
private static final int BUCKET_MASK = BUCKET_SIZE - 1;
|
private static final int BUCKET_MASK = BUCKET_SIZE - 1;
|
||||||
private static final int MAX_STR_LEN = 100;
|
private static final int MAX_STR_LEN = 100;
|
||||||
|
private static final int MAX_STATIONS = 10000;
|
||||||
private static final Unsafe UNSAFE = initUnsafe();
|
private static final Unsafe UNSAFE = initUnsafe();
|
||||||
private static final long[] HASH_MASKS = new long[]{
|
private static final long[] HASH_MASKS = new long[]{
|
||||||
0x0L,
|
0x0L,
|
||||||
@ -66,6 +67,33 @@ public class CalculateAverage_abeobk {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static class Stat {
|
||||||
|
Node node;
|
||||||
|
String key;
|
||||||
|
|
||||||
|
public final String toString() {
|
||||||
|
return (node.min / 10.0) + "/"
|
||||||
|
+ (Math.round(((double) node.sum / node.count)) / 10.0) + "/"
|
||||||
|
+ (node.max / 10.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
Stat(Node n) {
|
||||||
|
node = n;
|
||||||
|
byte[] sbuf = new byte[MAX_STR_LEN];
|
||||||
|
long word = UNSAFE.getLong(n.addr);
|
||||||
|
long semipos_code = getSemiPosCode(word);
|
||||||
|
int keylen = 0;
|
||||||
|
while (semipos_code == 0) {
|
||||||
|
keylen += 8;
|
||||||
|
word = UNSAFE.getLong(n.addr + keylen);
|
||||||
|
semipos_code = getSemiPosCode(word);
|
||||||
|
}
|
||||||
|
keylen += Long.numberOfTrailingZeros(semipos_code) >>> 3;
|
||||||
|
UNSAFE.copyMemory(null, n.addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
|
||||||
|
key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static class Node {
|
static class Node {
|
||||||
long addr;
|
long addr;
|
||||||
long word0;
|
long word0;
|
||||||
@ -73,37 +101,23 @@ public class CalculateAverage_abeobk {
|
|||||||
long sum;
|
long sum;
|
||||||
int count;
|
int count;
|
||||||
short min, max;
|
short min, max;
|
||||||
int keylen;
|
|
||||||
String key;
|
|
||||||
|
|
||||||
void calcKey() {
|
Node(long a, long t, short val) {
|
||||||
byte[] sbuf = new byte[MAX_STR_LEN];
|
|
||||||
UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
|
|
||||||
key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
|
|
||||||
}
|
|
||||||
|
|
||||||
public String toString() {
|
|
||||||
return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1);
|
|
||||||
}
|
|
||||||
|
|
||||||
Node(long a, long t, short val, int kl) {
|
|
||||||
addr = a;
|
addr = a;
|
||||||
tail = t;
|
tail = t;
|
||||||
keylen = kl;
|
|
||||||
sum = min = max = val;
|
sum = min = max = val;
|
||||||
count = 1;
|
count = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Node(long a, long w0, long t, short val, int kl) {
|
Node(long a, long w0, long t, short val) {
|
||||||
addr = a;
|
addr = a;
|
||||||
word0 = w0;
|
word0 = w0;
|
||||||
tail = t;
|
tail = t;
|
||||||
keylen = kl;
|
|
||||||
sum = min = max = val;
|
sum = min = max = val;
|
||||||
count = 1;
|
count = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
void add(short val) {
|
final void add(short val) {
|
||||||
sum += val;
|
sum += val;
|
||||||
count++;
|
count++;
|
||||||
if (val >= max) {
|
if (val >= max) {
|
||||||
@ -115,7 +129,7 @@ public class CalculateAverage_abeobk {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void merge(Node other) {
|
final void merge(Node other) {
|
||||||
sum += other.sum;
|
sum += other.sum;
|
||||||
count += other.count;
|
count += other.count;
|
||||||
if (other.max > max) {
|
if (other.max > max) {
|
||||||
@ -126,8 +140,8 @@ public class CalculateAverage_abeobk {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean contentEquals(long other_addr, long other_word0, long other_tail) {
|
final boolean contentEquals(long other_addr, long other_word0, long other_tail, int keylen) {
|
||||||
if (tail != other_tail || word0 != other_word0)
|
if (word0 != other_word0 || tail != other_tail)
|
||||||
return false;
|
return false;
|
||||||
// this is faster than comparision if key is short
|
// this is faster than comparision if key is short
|
||||||
long xsum = 0;
|
long xsum = 0;
|
||||||
@ -161,11 +175,8 @@ public class CalculateAverage_abeobk {
|
|||||||
|
|
||||||
// speed/collision balance
|
// speed/collision balance
|
||||||
static final int xxh32(long hash) {
|
static final int xxh32(long hash) {
|
||||||
final int p1 = 0x85EBCA77; // prime
|
long h = hash * 37;
|
||||||
int low = (int) hash;
|
return (int) (h ^ (h >>> 29));
|
||||||
int high = (int) (hash >>> 33);
|
|
||||||
int h = (low * p1) ^ high;
|
|
||||||
return h ^ (h >>> 17);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// great idea from merykitty (Quan Anh Mai)
|
// great idea from merykitty (Quan Anh Mai)
|
||||||
@ -185,11 +196,10 @@ public class CalculateAverage_abeobk {
|
|||||||
static final Node[] parse(int thread_id, long start, long end) {
|
static final Node[] parse(int thread_id, long start, long end) {
|
||||||
int cls = 0;
|
int cls = 0;
|
||||||
long addr = start;
|
long addr = start;
|
||||||
var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
|
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
|
||||||
// parse loop
|
// parse loop
|
||||||
while (addr < end) {
|
while (addr < end) {
|
||||||
long row_addr = addr;
|
long row_addr = addr;
|
||||||
long hash = 0;
|
|
||||||
|
|
||||||
long word0 = UNSAFE.getLong(addr);
|
long word0 = UNSAFE.getLong(addr);
|
||||||
long semipos_code = getSemiPosCode(word0);
|
long semipos_code = getSemiPosCode(word0);
|
||||||
@ -202,14 +212,14 @@ public class CalculateAverage_abeobk {
|
|||||||
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
||||||
addr += (dot_pos >>> 3) + 3;
|
addr += (dot_pos >>> 3) + 3;
|
||||||
|
|
||||||
long tail = (word0 & HASH_MASKS[semi_pos]);
|
long tail = word0 & HASH_MASKS[semi_pos];
|
||||||
int bucket = xxh32(tail) & BUCKET_MASK;
|
int bucket = xxh32(tail) & BUCKET_MASK;
|
||||||
short val = parseNum(num_word, dot_pos);
|
short val = parseNum(num_word, dot_pos);
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
var node = map[bucket];
|
var node = map[bucket];
|
||||||
if (node == null) {
|
if (node == null) {
|
||||||
map[bucket] = new Node(row_addr, tail, val, semi_pos);
|
map[bucket] = new Node(row_addr, tail, val);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (node.tail == tail) {
|
if (node.tail == tail) {
|
||||||
@ -223,28 +233,25 @@ public class CalculateAverage_abeobk {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
hash ^= word0;
|
|
||||||
addr += 8;
|
addr += 8;
|
||||||
long word = UNSAFE.getLong(addr);
|
long word = UNSAFE.getLong(addr);
|
||||||
semipos_code = getSemiPosCode(word);
|
semipos_code = getSemiPosCode(word);
|
||||||
// 43% chance
|
// 43% chance
|
||||||
if (semipos_code != 0) {
|
if (semipos_code != 0) {
|
||||||
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
|
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
|
||||||
addr += semi_pos;
|
addr += semi_pos + 1;
|
||||||
int keylen = (int) (addr - row_addr);
|
long num_word = UNSAFE.getLong(addr);
|
||||||
long num_word = UNSAFE.getLong(addr + 1);
|
|
||||||
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
||||||
addr += (dot_pos >>> 3) + 4;
|
addr += (dot_pos >>> 3) + 3;
|
||||||
|
|
||||||
long tail = (word & HASH_MASKS[semi_pos]);
|
long tail = (word & HASH_MASKS[semi_pos]);
|
||||||
hash ^= tail;
|
int bucket = xxh32(word0 ^ tail) & BUCKET_MASK;
|
||||||
int bucket = xxh32(hash) & BUCKET_MASK;
|
|
||||||
short val = parseNum(num_word, dot_pos);
|
short val = parseNum(num_word, dot_pos);
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
var node = map[bucket];
|
var node = map[bucket];
|
||||||
if (node == null) {
|
if (node == null) {
|
||||||
map[bucket] = new Node(row_addr, word0, tail, val, keylen);
|
map[bucket] = new Node(row_addr, word0, tail, val);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (node.word0 == word0 && node.tail == tail) {
|
if (node.word0 == word0 && node.tail == tail) {
|
||||||
@ -258,6 +265,9 @@ public class CalculateAverage_abeobk {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// why not going for more? tested, slower
|
||||||
|
|
||||||
|
long hash = word0;
|
||||||
while (semipos_code == 0) {
|
while (semipos_code == 0) {
|
||||||
hash ^= word;
|
hash ^= word;
|
||||||
addr += 8;
|
addr += 8;
|
||||||
@ -273,17 +283,16 @@ public class CalculateAverage_abeobk {
|
|||||||
addr += (dot_pos >>> 3) + 4;
|
addr += (dot_pos >>> 3) + 4;
|
||||||
|
|
||||||
long tail = (word & HASH_MASKS[semi_pos]);
|
long tail = (word & HASH_MASKS[semi_pos]);
|
||||||
hash ^= tail;
|
int bucket = xxh32(hash ^ tail) & BUCKET_MASK;
|
||||||
int bucket = xxh32(hash) & BUCKET_MASK;
|
|
||||||
short val = parseNum(num_word, dot_pos);
|
short val = parseNum(num_word, dot_pos);
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
var node = map[bucket];
|
var node = map[bucket];
|
||||||
if (node == null) {
|
if (node == null) {
|
||||||
map[bucket] = new Node(row_addr, word0, tail, val, keylen);
|
map[bucket] = new Node(row_addr, word0, tail, val);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (node.contentEquals(row_addr, word0, tail)) {
|
if (node.contentEquals(row_addr, word0, tail, keylen)) {
|
||||||
node.add(val);
|
node.add(val);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -292,6 +301,7 @@ public class CalculateAverage_abeobk {
|
|||||||
cls++;
|
cls++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (SHOW_ANALYSIS) {
|
if (SHOW_ANALYSIS) {
|
||||||
debug("Thread %d collision = %d", thread_id, cls);
|
debug("Thread %d collision = %d", thread_id, cls);
|
||||||
}
|
}
|
||||||
@ -307,8 +317,6 @@ public class CalculateAverage_abeobk {
|
|||||||
workerCommand.add("--worker");
|
workerCommand.add("--worker");
|
||||||
new ProcessBuilder()
|
new ProcessBuilder()
|
||||||
.command(workerCommand)
|
.command(workerCommand)
|
||||||
.inheritIO()
|
|
||||||
.redirectOutput(ProcessBuilder.Redirect.PIPE)
|
|
||||||
.start()
|
.start()
|
||||||
.getInputStream()
|
.getInputStream()
|
||||||
.transferTo(System.out);
|
.transferTo(System.out);
|
||||||
@ -333,43 +341,29 @@ public class CalculateAverage_abeobk {
|
|||||||
// processing
|
// processing
|
||||||
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
|
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
|
||||||
|
|
||||||
TreeMap<String, Node> ms = new TreeMap<>();
|
List<List<Stat>> maps = IntStream.range(0, cpu_cnt)
|
||||||
int[] lenhist = new int[64]; // length histogram
|
|
||||||
|
|
||||||
List<List<Node>> maps = IntStream.range(0, cpu_cnt)
|
|
||||||
.mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1]))
|
.mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1]))
|
||||||
.map(map -> {
|
.map(map -> {
|
||||||
List<Node> nodes = new ArrayList<>();
|
List<Stat> stats = new ArrayList<>();
|
||||||
for (var node : map) {
|
for (var node : map) {
|
||||||
if (node == null)
|
if (node == null)
|
||||||
continue;
|
continue;
|
||||||
node.calcKey();
|
stats.add(new Stat(node));
|
||||||
nodes.add(node);
|
|
||||||
}
|
}
|
||||||
return nodes;
|
return stats;
|
||||||
})
|
})
|
||||||
.parallel()
|
.parallel()
|
||||||
.toList();
|
.toList();
|
||||||
|
|
||||||
for (var nodes : maps) {
|
TreeMap<String, Stat> ms = new TreeMap<>();
|
||||||
for (var node : nodes) {
|
for (var stats : maps) {
|
||||||
if (SHOW_ANALYSIS) {
|
for (var s : stats) {
|
||||||
int kl = node.keylen & (lenhist.length - 1);
|
var stat = ms.putIfAbsent(s.key, s);
|
||||||
lenhist[kl] += node.count;
|
|
||||||
}
|
|
||||||
var stat = ms.putIfAbsent(node.key, node);
|
|
||||||
if (stat != null)
|
if (stat != null)
|
||||||
stat.merge(node);
|
stat.node.merge(s.node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (SHOW_ANALYSIS) {
|
|
||||||
debug("Total = " + Arrays.stream(lenhist).sum());
|
|
||||||
debug("Length_histogram = "
|
|
||||||
+ Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray()));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// print result
|
// print result
|
||||||
System.out.println(ms);
|
System.out.println(ms);
|
||||||
System.out.close();
|
System.out.close();
|
||||||
|
Loading…
Reference in New Issue
Block a user