CalculateAverage_gonix update (#706)

Backported some of the optimizations from unsafe solution.

Co-authored-by: Giedrius D <d.giedrius@gmail.com>
This commit is contained in:
gonix 2024-02-01 12:53:46 +02:00 committed by GitHub
parent fdd539e1f9
commit 1e7314d5fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 320 additions and 209 deletions

View File

@ -1,4 +1,4 @@
#!/bin/sh #!/bin/bash
# #
# Copyright 2023 The original authors # Copyright 2023 The original authors
# #
@ -17,4 +17,4 @@
JAVA_OPTS="--enable-preview" JAVA_OPTS="--enable-preview"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_gonix exec cat < <(exec java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_gonix)

View File

@ -46,6 +46,7 @@ public class CalculateAverage_gonix {
TreeMap::new)); TreeMap::new));
System.out.println(res); System.out.println(res);
System.out.close();
} }
private static List<MappedByteBuffer> buildChunks(RandomAccessFile file) throws IOException { private static List<MappedByteBuffer> buildChunks(RandomAccessFile file) throws IOException {
@ -75,9 +76,8 @@ public class CalculateAverage_gonix {
} }
return chunks; return chunks;
} }
}
class Aggregator { private static class Aggregator {
private static final int MAX_STATIONS = 10_000; private static final int MAX_STATIONS = 10_000;
private static final int MAX_STATION_SIZE = Math.ceilDiv(100, 8) + 5; private static final int MAX_STATION_SIZE = Math.ceilDiv(100, 8) + 5;
private static final int INDEX_SIZE = 1024 * 1024; private static final int INDEX_SIZE = 1024 * 1024;
@ -132,66 +132,64 @@ class Aggregator {
while (pos < limit) { while (pos < limit) {
int start = pos; int start = pos;
int hash = 0; long keyLong = buf.getLong(pos);
long tail = 0; long valueSepMark = valueSepMark(keyLong);
while (true) { if (valueSepMark != 0) {
// Seen this trick used in multiple other solutions. int tailBits = tailBits(valueSepMark);
// Nice breakdown here: https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord pos += valueOffset(tailBits);
long tmpLong = buf.getLong(pos); // assert (UNSAFE.getByte(pos - 1) == ';') : "Expected ';' (1), pos=" + (pos - startAddr);
long match = tmpLong ^ 0x3B3B3B3B_3B3B3B3BL; // 3B == ';' long tailAndLen = tailAndLen(tailBits, keyLong, pos - start - 1);
match = ((match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L));
if (match == 0) { long valueLong = buf.getLong(pos);
hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFFFFFFFF); int decimalSepMark = decimalSepMark(valueLong);
pos += 8; pos += nextKeyOffset(decimalSepMark);
// assert (UNSAFE.getByte(pos - 1) == '\n') : "Expected '\\n' (1), pos=" + (pos - startAddr);
int measurement = decimalValue(decimalSepMark, valueLong);
add1(buf, start, tailAndLen, hash(hash1(tailAndLen)), measurement);
continue; continue;
} }
int tailBits = Long.numberOfTrailingZeros(match >>> 7); pos += 8;
long tailMask = ~(-1L << tailBits); long keyLong1 = keyLong;
tail = tmpLong & tailMask; keyLong = buf.getLong(pos);
hash = ((33 * hash) ^ (int) (tail & 0xFFFFFFFF)) + (int) ((tail >>> 33) & 0xFFFFFFFF); valueSepMark = valueSepMark(keyLong);
pos += tailBits >> 3; if (valueSepMark != 0) {
break; int tailBits = tailBits(valueSepMark);
pos += valueOffset(tailBits);
// assert (UNSAFE.getByte(pos - 1) == ';') : "Expected ';' (2), pos=" + (pos - startAddr);
long tailAndLen = tailAndLen(tailBits, keyLong, pos - start - 1);
long valueLong = buf.getLong(pos);
int decimalSepMark = decimalSepMark(valueLong);
pos += nextKeyOffset(decimalSepMark);
// assert (UNSAFE.getByte(pos - 1) == '\n') : "Expected '\\n' (2), pos=" + (pos - startAddr);
int measurement = decimalValue(decimalSepMark, valueLong);
add2(buf, start, keyLong1, tailAndLen, hash(hash(hash1(keyLong1), tailAndLen)), measurement);
continue;
} }
hash = (33 * hash) ^ (hash >>> 15);
int lenInLongs = (pos - start) >> 3;
long tailAndLen = (tail << 8) | (lenInLongs & 0xFF);
// assert (buf.get(pos) == ';') : "Expected ';'";
pos++;
int measurement; long hash = hash1(keyLong1);
{ do {
// Seen this trick used in multiple other solutions. pos += 8;
// Looks like the original author is @merykitty. hash = hash(hash, keyLong);
long tmpLong = buf.getLong(pos); keyLong = buf.getLong(pos);
valueSepMark = valueSepMark(keyLong);
} while (valueSepMark == 0);
int tailBits = tailBits(valueSepMark);
pos += valueOffset(tailBits);
// assert (UNSAFE.getByte(pos - 1) == ';') : "Expected ';' (N), pos=" + (pos - startAddr);
long tailAndLen = tailAndLen(tailBits, keyLong, pos - start - 1);
hash = hash(hash, tailAndLen);
// The 4th binary digit of the ascii of a digit is 1 while long valueLong = buf.getLong(pos);
// that of the '.' is 0. This finds the decimal separator int decimalSepMark = decimalSepMark(valueLong);
// The value can be 12, 20, 28 pos += nextKeyOffset(decimalSepMark);
int decimalSepPos = Long.numberOfTrailingZeros(~tmpLong & 0x10101000); // assert (UNSAFE.getByte(pos - 1) == '\n') : "Expected '\\n' (N), pos=" + (pos - startAddr);
int shift = 28 - decimalSepPos; int measurement = decimalValue(decimalSepMark, valueLong);
// signed is -1 if negative, 0 otherwise
long signed = (~tmpLong << 59) >> 63;
long designMask = ~(signed & 0xFF);
// Align the number to a specific position and transform the ascii code
// to actual digit value in each byte
long digits = ((tmpLong & designMask) << shift) & 0x0F000F0F00L;
// Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit) addN(buf, start, tailAndLen, hash(hash), measurement);
// 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
// 0x000000UU00TTHH00 +
// 0x00UU00TTHH000000 * 10 +
// 0xUU00TTHH00000000 * 100
// Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400
// This results in our value lies in the bit 32 to 41 of this product
// That was close :)
long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
measurement = (int) ((absValue ^ signed) - signed);
pos += (decimalSepPos >>> 3) + 3;
}
// assert (buf.get(pos - 1) == '\n') : "Expected '\\n'";
add(buf, start, tailAndLen, hash, measurement);
} }
return this; return this;
@ -203,10 +201,103 @@ class Aggregator {
.mapToObj(offset -> new Entry(mem, offset)); .mapToObj(offset -> new Entry(mem, offset));
} }
private void add(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) { private static long hash1(long value) {
return value;
}
private static long hash(long hash, long value) {
return hash ^ value;
}
private static int hash(long hash) {
hash *= 0x9E3779B97F4A7C15L; // Fibonacci hashing multiplier
return (int) (hash >>> 39);
}
private static long valueSepMark(long keyLong) {
// Seen this trick used in multiple other solutions.
// Nice breakdown here: https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
long match = keyLong ^ 0x3B3B3B3B_3B3B3B3BL; // 3B == ';'
match = (match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L);
return match;
}
private static int tailBits(long valueSepMark) {
return Long.numberOfTrailingZeros(valueSepMark >>> 7);
}
private static int valueOffset(int tailBits) {
return (int) (tailBits >>> 3) + 1;
}
private static long tailAndLen(int tailBits, long keyLong, long keyLen) {
long tailMask = ~(-1L << tailBits);
long tail = keyLong & tailMask;
return (tail << 8) | ((keyLen >> 3) & 0xFF);
}
private static int decimalSepMark(long value) {
// Seen this trick used in multiple other solutions.
// Looks like the original author is @merykitty.
// The 4th binary digit of the ascii of a digit is 1 while
// that of the '.' is 0. This finds the decimal separator
// The value can be 12, 20, 28
return Long.numberOfTrailingZeros(~value & 0x10101000);
}
private static int decimalValue(int decimalSepMark, long value) {
// Seen this trick used in multiple other solutions.
// Looks like the original author is @merykitty.
int shift = 28 - decimalSepMark;
// signed is -1 if negative, 0 otherwise
long signed = (~value << 59) >> 63;
long designMask = ~(signed & 0xFF);
// Align the number to a specific position and transform the ascii code
// to actual digit value in each byte
long digits = ((value & designMask) << shift) & 0x0F000F0F00L;
// Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit)
// 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
// 0x000000UU00TTHH00 +
// 0x00UU00TTHH000000 * 10 +
// 0xUU00TTHH00000000 * 100
// Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400
// This results in our value lies in the bit 32 to 41 of this product
// That was close :)
long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
return (int) ((absValue ^ signed) - signed);
}
private static int nextKeyOffset(int decimalSepMark) {
return (decimalSepMark >>> 3) + 3;
}
private void add1(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) {
int idx = hash & INDEX_MASK; int idx = hash & INDEX_MASK;
for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) { for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) {
if (update(index[idx], buf, start, tailAndLen, measurement)) { if (update1(index[idx], tailAndLen, measurement)) {
return;
}
}
index[idx] = create(buf, start, tailAndLen, measurement);
}
private void add2(ByteBuffer buf, int start, long keyLong, long tailAndLen, int hash, int measurement) {
int idx = hash & INDEX_MASK;
for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) {
if (update2(index[idx], keyLong, tailAndLen, measurement)) {
return;
}
}
index[idx] = create(buf, start, tailAndLen, measurement);
}
private void addN(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) {
int idx = hash & INDEX_MASK;
for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) {
if (updateN(index[idx], buf, start, tailAndLen, measurement)) {
return; return;
} }
} }
@ -236,7 +327,23 @@ class Aggregator {
return offset; return offset;
} }
private boolean update(int offset, ByteBuffer buf, int start, long tailAndLen, int measurement) { private boolean update1(int offset, long tailAndLen, int measurement) {
if (mem[offset] != tailAndLen) {
return false;
}
updateStats(offset + 1, measurement);
return true;
}
private boolean update2(int offset, long keyLong, long tailAndLen, int measurement) {
if (mem[offset] != tailAndLen || mem[offset + 1] != keyLong) {
return false;
}
updateStats(offset + 2, measurement);
return true;
}
private boolean updateN(int offset, ByteBuffer buf, int start, long tailAndLen, int measurement) {
var mem = this.mem; var mem = this.mem;
if (mem[offset] != tailAndLen) { if (mem[offset] != tailAndLen) {
return false; return false;
@ -251,7 +358,11 @@ class Aggregator {
memPos += 1; memPos += 1;
bufPos += 8; bufPos += 8;
} }
updateStats(memPos, measurement);
return true;
}
private void updateStats(int memPos, int measurement) {
mem[memPos + FLD_COUNT] += 1; mem[memPos + FLD_COUNT] += 1;
mem[memPos + FLD_SUM] += measurement; mem[memPos + FLD_SUM] += measurement;
if (measurement < mem[memPos + FLD_MIN]) { if (measurement < mem[memPos + FLD_MIN]) {
@ -260,8 +371,6 @@ class Aggregator {
if (measurement > mem[memPos + FLD_MAX]) { if (measurement > mem[memPos + FLD_MAX]) {
mem[memPos + FLD_MAX] = measurement; mem[memPos + FLD_MAX] = measurement;
} }
return true;
} }
public static class Entry { public static class Entry {
@ -320,3 +429,5 @@ class Aggregator {
} }
} }
} }
}