apply loop unroll trick (#643)

* apply loop unroll trick

* less assign op, a bit faster
This commit is contained in:
Van Phu DO 2024-01-30 05:21:04 +09:00 committed by GitHub
parent 31a6740ef1
commit 8e407ca79d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 179 additions and 129 deletions

View File

@ -20,8 +20,6 @@ sdk use java 21.0.2-graal 1>&2
# ./mvnw clean verify removes target/ and will re-trigger native image creation. # ./mvnw clean verify removes target/ and will re-trigger native image creation.
if [ ! -f target/CalculateAverage_abeobk_image ]; then if [ ! -f target/CalculateAverage_abeobk_image ]; then
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -dsa -march=native -R:MaxHeapSize=128m -H:-GenLoopSafepoints -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk" NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -dsa -march=native -H:InlineAllBonus=10 -H:-GenLoopSafepoints -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk
fi fi

View File

@ -98,21 +98,21 @@ public class CalculateAverage_abeobk {
return new String(sbuf, 0, (int) keylen, StandardCharsets.UTF_8); return new String(sbuf, 0, (int) keylen, StandardCharsets.UTF_8);
} }
Node(long a, long t, int kl, long h, long val) { Node(long a, long t, int kl, long h) {
addr = a; addr = a;
tail = t; tail = t;
sum = min = max = val; min = 999;
count = 1; max = -999;
keylen = kl; keylen = kl;
hash = h; hash = h;
} }
Node(long a, long w0, long t, int kl, long h, long val) { Node(long a, long w0, long t, int kl, long h) {
addr = a; addr = a;
word0 = w0; word0 = w0;
min = 999;
max = -999;
tail = t; tail = t;
sum = min = max = val;
count = 1;
keylen = kl; keylen = kl;
hash = h; hash = h;
} }
@ -120,9 +120,8 @@ public class CalculateAverage_abeobk {
final void add(long val) { final void add(long val) {
sum += val; sum += val;
count++; count++;
if (val >= max) { if (val > max) {
max = val; max = val;
return;
} }
if (val < min) { if (val < min) {
min = val; min = val;
@ -170,14 +169,52 @@ public class CalculateAverage_abeobk {
return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L); return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L);
} }
static final long getLFCode(final long word) {
long xor_semi = word ^ 0x0A0A0A0A0A0A0A0AL; // xor with \n\n\n\n\n\n\n\n
return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L);
}
static final long nextLine(long addr) {
long word = UNSAFE.getLong(addr);
long lfpos_code = getLFCode(word);
while (lfpos_code == 0) {
addr += 8;
word = UNSAFE.getLong(addr);
lfpos_code = getLFCode(word);
}
return addr + (Long.numberOfTrailingZeros(lfpos_code) >>> 3) + 1;
}
// speed/collision balance // speed/collision balance
static final long xxh32(long hash) { static final long xxh32(long hash) {
long h = hash * 37; long h = hash * 37;
return (h ^ (h >>> 29)); return (h ^ (h >>> 29));
} }
static final class ChunkParser {
long addr;
long end;
Node[] map;
ChunkParser(Node[] m, long a, long e) {
map = m;
addr = a;
end = e;
}
final boolean ok() {
return addr < end;
}
final long word() {
return UNSAFE.getLong(addr);
}
final long val() {
long num_word = UNSAFE.getLong(addr);
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 3;
// great idea from merykitty (Quan Anh Mai) // great idea from merykitty (Quan Anh Mai)
static final long parseNum(long num_word, int dot_pos) {
int shift = 28 - dot_pos; int shift = 28 - dot_pos;
long signed = (~num_word << 59) >> 63; long signed = (~num_word << 59) >> 63;
long dsmask = ~(signed & 0xFF); long dsmask = ~(signed & 0xFF);
@ -186,71 +223,29 @@ public class CalculateAverage_abeobk {
return ((abs_val ^ signed) - signed); return ((abs_val ^ signed) - signed);
} }
// Thread pool worker
static final class Worker extends Thread {
final int thread_id; // for debug use only
Worker(int i) {
thread_id = i;
this.start();
}
@Override
public void run() {
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
int id;
int cls = 0;
// process in small chunk to maintain disk locality (artsiomkorzun trick)
while ((id = chunk_id.getAndIncrement()) < chunk_cnt) {
long addr = start_addr + id * CHUNK_SZ;
long end = Math.min(addr + CHUNK_SZ, end_addr);
// find start of line
if (id > 0) {
while (UNSAFE.getByte(addr++) != '\n')
;
}
// parse loop
// optimize for contest // optimize for contest
// save as much slow memory access as possible // save as much slow memory access as possible
// about 50% key < 8chars, 25% key bettween 8-10 chars // about 50% key < 8chars, 25% key bettween 8-10 chars
// keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
while (addr < end) { final Node key(long word0, long semipos_code) {
long row_addr = addr; long row_addr = addr;
long word0 = UNSAFE.getLong(addr);
long semipos_code = getSemiPosCode(word0);
// about 50% chance key < 8 chars // about 50% chance key < 8 chars
if (semipos_code != 0) { if (semipos_code != 0) {
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
addr += semi_pos + 1; addr += semi_pos + 1;
long num_word = UNSAFE.getLong(addr);
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 3;
long tail = word0 & HASH_MASKS[semi_pos]; long tail = word0 & HASH_MASKS[semi_pos];
long hash = xxh32(tail); long hash = xxh32(tail);
int bucket = (int) (hash & BUCKET_MASK); int bucket = (int) (hash & BUCKET_MASK);
long val = parseNum(num_word, dot_pos);
while (true) { while (true) {
var node = map[bucket]; Node node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, tail, semi_pos, hash, val); return (map[bucket] = new Node(row_addr, tail, semi_pos, hash));
break;
} }
if (node.tail == tail) { if (node.tail == tail) {
node.add(val); return node;
break;
} }
bucket++; bucket++;
if (SHOW_ANALYSIS)
cls++;
} }
continue;
} }
addr += 8; addr += 8;
@ -260,30 +255,19 @@ public class CalculateAverage_abeobk {
if (semipos_code != 0) { if (semipos_code != 0) {
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
addr += semi_pos + 1; addr += semi_pos + 1;
long num_word = UNSAFE.getLong(addr);
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 3;
long tail = (word & HASH_MASKS[semi_pos]); long tail = (word & HASH_MASKS[semi_pos]);
long hash = xxh32(word0 ^ tail); long hash = xxh32(word0 ^ tail);
int bucket = (int) (hash & BUCKET_MASK); int bucket = (int) (hash & BUCKET_MASK);
long val = parseNum(num_word, dot_pos);
while (true) { while (true) {
var node = map[bucket]; Node node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash, val); return (map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash));
break;
} }
if (node.word0 == word0 && node.tail == tail) { if (node.word0 == word0 && node.tail == tail) {
node.add(val); return node;
break;
} }
bucket++; bucket++;
if (SHOW_ANALYSIS)
cls++;
} }
continue;
} }
// why not going for more? tested, slower // why not going for more? tested, slower
@ -298,32 +282,100 @@ public class CalculateAverage_abeobk {
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
addr += semi_pos; addr += semi_pos;
long keylen = addr - row_addr; long keylen = addr - row_addr;
long num_word = UNSAFE.getLong(addr + 1); addr++;
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 4;
long tail = (word & HASH_MASKS[semi_pos]); long tail = (word & HASH_MASKS[semi_pos]);
hash = xxh32(hash ^ tail); hash = xxh32(hash ^ tail);
int bucket = (int) (hash & BUCKET_MASK); int bucket = (int) (hash & BUCKET_MASK);
long val = parseNum(num_word, dot_pos);
while (true) { while (true) {
var node = map[bucket]; Node node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash, val); return (map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash));
break;
} }
if (node.contentEquals(row_addr, word0, tail, keylen)) { if (node.contentEquals(row_addr, word0, tail, keylen)) {
node.add(val); return node;
break;
} }
bucket++; bucket++;
if (SHOW_ANALYSIS)
cls++;
} }
} }
} }
// Thread pool worker
static final class Worker extends Thread {
final int thread_id; // for debug use only
int cls = 0;
Worker(int i) {
thread_id = i;
this.start();
}
@Override
public void run() {
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
int id;
// process in small chunk to maintain disk locality (artsiomkorzun trick)
while ((id = chunk_id.getAndIncrement()) < chunk_cnt) {
long addr = start_addr + id * CHUNK_SZ;
long end = Math.min(addr + CHUNK_SZ, end_addr);
// find start of line
if (id > 0) {
addr = nextLine(addr);
}
final int num_segs = 3;
long seglen = (end - addr) / num_segs;
long a0 = addr;
long a1 = nextLine(addr + 1 * seglen);
long a2 = nextLine(addr + 2 * seglen);
ChunkParser p0 = new ChunkParser(map, a0, a1);
ChunkParser p1 = new ChunkParser(map, a1, a2);
ChunkParser p2 = new ChunkParser(map, a2, end);
while (p0.ok() && p1.ok() && p2.ok()) {
long w0 = p0.word();
long w1 = p1.word();
long w2 = p2.word();
long sc0 = getSemiPosCode(w0);
long sc1 = getSemiPosCode(w1);
long sc2 = getSemiPosCode(w2);
Node n0 = p0.key(w0, sc0);
Node n1 = p1.key(w1, sc1);
Node n2 = p2.key(w2, sc2);
long v0 = p0.val();
long v1 = p1.val();
long v2 = p2.val();
n0.add(v0);
n1.add(v1);
n2.add(v2);
}
while (p0.ok()) {
long w = p0.word();
long sc = getSemiPosCode(w);
Node n = p0.key(w, sc);
long v = p0.val();
n.add(v);
}
while (p1.ok()) {
long w = p1.word();
long sc = getSemiPosCode(w);
Node n = p1.key(w, sc);
long v = p1.val();
n.add(v);
}
while (p2.ok()) {
long w = p2.word();
long sc = getSemiPosCode(w);
Node n = p2.key(w, sc);
long v = p2.val();
n.add(v);
}
}
// merge is cheaper than string casting (artsiomkorzun) // merge is cheaper than string casting (artsiomkorzun)
while (!mapref.compareAndSet(null, map)) { while (!mapref.compareAndSet(null, map)) {
var other_map = mapref.getAndSet(null); var other_map = mapref.getAndSet(null);