Second attempt with various improvements (#510)
* Initial chunked impl
* Bytes instead of chars
* Improved number parsing
* Custom hashmap
* Graal and some tuning
* Fix segmenting
* Fix casing
* Unsafe
* Inlining hash calc
* Improved loop
* Cleanup
* Speeding up equals
* Simplifying hash
* Replace concurrenthashmap with lock
* Small changes
* Script reorg
* Native
* Lots of inlining and improvements
* Add back length check
* Fixes
* Small changes

---------

Co-authored-by: Jamal Mulla <j.mulla@mwam.com>
This commit is contained in:
parent
b91c95a498
commit
e639e2a045
@ -15,5 +15,11 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
JAVA_OPTS="--enable-preview -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -XX:+UseTransparentHugePages"
|
||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_JamalMulla
|
||||
|
||||
|
||||
if [ -f target/CalculateAverage_JamalMulla_image ]; then
|
||||
target/CalculateAverage_JamalMulla_image
|
||||
else
|
||||
JAVA_OPTS="--enable-preview -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -XX:+UseTransparentHugePages -XX:-TieredCompilation"
|
||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_JamalMulla
|
||||
fi
|
@ -16,4 +16,10 @@
|
||||
#
|
||||
|
||||
source "$HOME/.sdkman/bin/sdkman-init.sh"
|
||||
sdk use java 21.0.1-graal 1>&2
|
||||
sdk use java 21.0.2-graal 1>&2
|
||||
|
||||
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
|
||||
if [ ! -f target/CalculateAverage_JamalMulla_image ]; then
|
||||
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview --strict-image-heap --link-at-build-time -R:MaxHeapSize=64m -da -dsa --no-fallback --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_JamalMulla"
|
||||
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_JamalMulla_image dev.morling.onebrc.CalculateAverage_JamalMulla
|
||||
fi
|
@ -21,21 +21,32 @@ import java.io.IOException;
|
||||
import java.io.RandomAccessFile;
|
||||
import java.lang.foreign.Arena;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.MappedByteBuffer;
|
||||
import java.nio.channels.FileChannel;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.*;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
public class CalculateAverage_JamalMulla {
|
||||
|
||||
private static final Map<String, ResultRow> global = new HashMap<>();
|
||||
private static final long ALL_SEMIS = 0x3B3B3B3B3B3B3B3BL;
|
||||
private static final Map<String, ResultRow> global = new TreeMap<>();
|
||||
private static final String FILE = "./measurements.txt";
|
||||
private static final Unsafe UNSAFE = initUnsafe();
|
||||
private static final Lock lock = new ReentrantLock();
|
||||
private static final int FNV_32_INIT = 0x811c9dc5;
|
||||
private static final int FNV_32_PRIME = 0x01000193;
|
||||
private static final long FXSEED = 0x517cc1b727220a95L;
|
||||
|
||||
// Byte-lane masks indexed by byte count: masks[k] keeps the k low-order bytes
// of a long and clears the rest (k = 0..7).
// unsafeEquals(long, long, byte) indexes this table with the number of bytes
// left over after its whole 8-byte word comparisons, so that only the tail of
// a key takes part in the final comparison.
private static final long[] masks = {
|
||||
0x0,
|
||||
0x00000000000000FFL,
|
||||
0x000000000000FFFFL,
|
||||
0x0000000000FFFFFFL,
|
||||
0x00000000FFFFFFFFL,
|
||||
0x000000FFFFFFFFFFL,
|
||||
0x0000FFFFFFFFFFFFL,
|
||||
0x00FFFFFFFFFFFFFFL
|
||||
};
|
||||
|
||||
private static Unsafe initUnsafe() {
|
||||
try {
|
||||
@ -53,12 +64,16 @@ public class CalculateAverage_JamalMulla {
|
||||
private int max;
|
||||
private long sum;
|
||||
private int count;
|
||||
private final long keyStart;
|
||||
private final byte keyLength;
|
||||
|
||||
private ResultRow(int v) {
|
||||
private ResultRow(int v, final long keyStart, final byte keyLength) {
|
||||
this.min = v;
|
||||
this.max = v;
|
||||
this.sum = v;
|
||||
this.count = 1;
|
||||
this.keyStart = keyStart;
|
||||
this.keyLength = keyLength;
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
@ -68,236 +83,197 @@ public class CalculateAverage_JamalMulla {
|
||||
private double round(double value) {
|
||||
return Math.round(value) / 10.0;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private record Chunk(Long start, Long length) {
|
||||
}
|
||||
|
||||
static List<Chunk> getChunks(int numThreads, FileChannel channel) throws IOException {
|
||||
static Chunk[] getChunks(int numThreads, FileChannel channel) throws IOException {
|
||||
// get all chunk boundaries
|
||||
final long filebytes = channel.size();
|
||||
final long roughChunkSize = filebytes / numThreads;
|
||||
final List<Chunk> chunks = new ArrayList<>(numThreads);
|
||||
final Chunk[] chunks = new Chunk[numThreads];
|
||||
final long mappedAddress = channel.map(FileChannel.MapMode.READ_ONLY, 0, filebytes, Arena.global()).address();
|
||||
long chunkStart = 0;
|
||||
long chunkLength = Math.min(filebytes - chunkStart - 1, roughChunkSize);
|
||||
int i = 0;
|
||||
while (chunkStart < filebytes) {
|
||||
// unlikely we need to read more than this many bytes to find the next newline
|
||||
MappedByteBuffer mbb = channel.map(FileChannel.MapMode.READ_ONLY, chunkStart + chunkLength,
|
||||
Math.min(Math.min(filebytes - chunkStart - chunkLength, chunkLength), 100));
|
||||
|
||||
while (mbb.get() != 0xA /* \n */) {
|
||||
while (UNSAFE.getByte(mappedAddress + chunkStart + chunkLength) != 0xA /* \n */) {
|
||||
chunkLength++;
|
||||
}
|
||||
|
||||
chunks.add(new Chunk(mappedAddress + chunkStart, chunkLength + 1));
|
||||
chunks[i++] = new Chunk(mappedAddress + chunkStart, chunkLength + 1);
|
||||
// to skip the nl in the next chunk
|
||||
chunkStart += chunkLength + 1;
|
||||
chunkLength = Math.min(filebytes - chunkStart - 1, roughChunkSize);
|
||||
}
|
||||
|
||||
return chunks;
|
||||
}
|
||||
|
||||
private static class CalculateTask implements Runnable {
|
||||
private static void run(Chunk chunk) {
|
||||
|
||||
private final SimplerHashMap results;
|
||||
private final Chunk chunk;
|
||||
|
||||
public CalculateTask(Chunk chunk) {
|
||||
this.results = new SimplerHashMap();
|
||||
this.chunk = chunk;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
// no names bigger than this
|
||||
final byte[] nameBytes = new byte[100];
|
||||
short nameIndex = 0;
|
||||
int ot;
|
||||
// fnv hash
|
||||
int hash = FNV_32_INIT;
|
||||
|
||||
long i = chunk.start;
|
||||
final long cl = chunk.start + chunk.length;
|
||||
while (i < cl) {
|
||||
byte c;
|
||||
while ((c = UNSAFE.getByte(i++)) != 0x3B /* semi-colon */) {
|
||||
nameBytes[nameIndex++] = c;
|
||||
hash ^= c;
|
||||
hash *= FNV_32_PRIME;
|
||||
}
|
||||
|
||||
// temperature value follows
|
||||
c = UNSAFE.getByte(i++);
|
||||
// we know the val has to be between -99.9 and 99.8
|
||||
// always with a single fractional digit
|
||||
// represented as a byte array of either 4 or 5 characters
|
||||
if (c == 0x2D /* minus sign */) {
|
||||
// could be either n.x or nn.x
|
||||
if (UNSAFE.getByte(i + 3) == 0xA) {
|
||||
ot = (UNSAFE.getByte(i++) - 48) * 10; // char 1
|
||||
}
|
||||
else {
|
||||
ot = (UNSAFE.getByte(i++) - 48) * 100; // char 1
|
||||
ot += (UNSAFE.getByte(i++) - 48) * 10; // char 2
|
||||
}
|
||||
i++; // skip dot
|
||||
ot += (UNSAFE.getByte(i++) - 48); // char 2
|
||||
ot = -ot;
|
||||
}
|
||||
else {
|
||||
// could be either n.x or nn.x
|
||||
if (UNSAFE.getByte(i + 2) == 0xA) {
|
||||
ot = (c - 48) * 10; // char 1
|
||||
}
|
||||
else {
|
||||
ot = (c - 48) * 100; // char 1
|
||||
ot += (UNSAFE.getByte(i++) - 48) * 10; // char 2
|
||||
}
|
||||
i++; // skip dot
|
||||
ot += (UNSAFE.getByte(i++) - 48); // char 3
|
||||
}
|
||||
|
||||
i++;// nl
|
||||
hash &= 65535;
|
||||
results.putOrMerge(nameBytes, nameIndex, hash, ot);
|
||||
// reset
|
||||
nameIndex = 0;
|
||||
hash = 0x811c9dc5;
|
||||
}
|
||||
|
||||
// merge results with overall results
|
||||
List<MapEntry> all = results.getAll();
|
||||
lock.lock();
|
||||
try {
|
||||
for (MapEntry me : all) {
|
||||
ResultRow rr;
|
||||
ResultRow lr = me.row;
|
||||
if ((rr = global.get(me.key)) != null) {
|
||||
rr.min = Math.min(rr.min, lr.min);
|
||||
rr.max = Math.max(rr.max, lr.max);
|
||||
rr.count += lr.count;
|
||||
rr.sum += lr.sum;
|
||||
}
|
||||
else {
|
||||
global.put(me.key, lr);
|
||||
}
|
||||
}
|
||||
}
|
||||
finally {
|
||||
lock.unlock();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws IOException, InterruptedException {
|
||||
FileChannel channel = new RandomAccessFile(FILE, "r").getChannel();
|
||||
int numThreads = 1;
|
||||
if (channel.size() > 64000) {
|
||||
numThreads = Runtime.getRuntime().availableProcessors();
|
||||
}
|
||||
List<Chunk> chunks = getChunks(numThreads, channel);
|
||||
List<Thread> threads = new ArrayList<>();
|
||||
for (Chunk chunk : chunks) {
|
||||
Thread thread = new Thread(new CalculateTask(chunk));
|
||||
thread.setPriority(Thread.MAX_PRIORITY);
|
||||
thread.start();
|
||||
threads.add(thread);
|
||||
}
|
||||
for (Thread t : threads) {
|
||||
t.join();
|
||||
}
|
||||
// create treemap just to sort
|
||||
System.out.println(new TreeMap<>(global));
|
||||
}
|
||||
|
||||
record MapEntry(String key, ResultRow row) {
|
||||
}
|
||||
|
||||
static class SimplerHashMap {
|
||||
// can't have more than 10000 unique keys but want to match max hash
|
||||
final int MAPSIZE = 65536;
|
||||
final ResultRow[] slots = new ResultRow[MAPSIZE];
|
||||
final byte[][] keys = new byte[MAPSIZE][];
|
||||
|
||||
public void putOrMerge(final byte[] key, final short length, final int hash, final int temp) {
|
||||
int slot = hash;
|
||||
ResultRow slotValue;
|
||||
byte nameLength;
|
||||
int temp;
|
||||
long hash;
|
||||
|
||||
long i = chunk.start;
|
||||
final long cl = chunk.start + chunk.length;
|
||||
long word;
|
||||
long hs;
|
||||
long start;
|
||||
byte c;
|
||||
int slot;
|
||||
long n;
|
||||
ResultRow slotValue;
|
||||
|
||||
while (i < cl) {
|
||||
start = i;
|
||||
hash = 0;
|
||||
|
||||
word = UNSAFE.getLong(i);
|
||||
|
||||
while (true) {
|
||||
n = word ^ ALL_SEMIS;
|
||||
hs = (n - 0x0101010101010101L) & (~n & 0x8080808080808080L);
|
||||
if (hs != 0)
|
||||
break;
|
||||
hash = (hash ^ word) * FXSEED;
|
||||
i += 8;
|
||||
word = UNSAFE.getLong(i);
|
||||
}
|
||||
|
||||
i += Long.numberOfTrailingZeros(hs) >> 3;
|
||||
|
||||
// hash of what's left ((hs >>> 7) - 1) masks off the bytes from word that are before the semicolon
|
||||
hash = (hash ^ word & (hs >>> 7) - 1) * FXSEED;
|
||||
nameLength = (byte) (i++ - start);
|
||||
|
||||
// temperature value follows
|
||||
c = UNSAFE.getByte(i++);
|
||||
// we know the val has to be between -99.9 and 99.8
|
||||
// always with a single fractional digit
|
||||
// represented as a byte array of either 4 or 5 characters
|
||||
if (c != 0x2D /* minus sign */) {
|
||||
// could be either n.x or nn.x
|
||||
if (UNSAFE.getByte(i + 2) == 0xA) {
|
||||
temp = (c - 48) * 10; // char 1
|
||||
}
|
||||
else {
|
||||
temp = (c - 48) * 100; // char 1
|
||||
temp += (UNSAFE.getByte(i++) - 48) * 10; // char 2
|
||||
}
|
||||
temp += (UNSAFE.getByte(++i) - 48); // char 3
|
||||
}
|
||||
else {
|
||||
// could be either n.x or nn.x
|
||||
if (UNSAFE.getByte(i + 3) == 0xA) {
|
||||
temp = (UNSAFE.getByte(i) - 48) * 10; // char 1
|
||||
i += 2;
|
||||
}
|
||||
else {
|
||||
temp = (UNSAFE.getByte(i) - 48) * 100; // char 1
|
||||
temp += (UNSAFE.getByte(i + 1) - 48) * 10; // char 2
|
||||
i += 3;
|
||||
}
|
||||
temp += (UNSAFE.getByte(i) - 48); // char 2
|
||||
temp = -temp;
|
||||
}
|
||||
i += 2;
|
||||
|
||||
// xor folding
|
||||
slot = (int) (hash ^ hash >> 32) & 65535;
|
||||
|
||||
// Linear probe for open slot
|
||||
while ((slotValue = slots[slot]) != null && (keys[slot].length != length || !unsafeEquals(keys[slot], key, length))) {
|
||||
slot++;
|
||||
while ((slotValue = slots[slot]) != null && (slotValue.keyLength != nameLength || !unsafeEquals(slotValue.keyStart, start, nameLength))) {
|
||||
slot = (slot + 1) % MAPSIZE;
|
||||
}
|
||||
|
||||
// existing
|
||||
if (slotValue != null) {
|
||||
slotValue.min = Math.min(slotValue.min, temp);
|
||||
slotValue.max = Math.max(slotValue.max, temp);
|
||||
slotValue.sum += temp;
|
||||
slotValue.count++;
|
||||
return;
|
||||
}
|
||||
if (temp > slotValue.max) {
|
||||
slotValue.max = temp;
|
||||
continue;
|
||||
}
|
||||
if (temp < slotValue.min)
|
||||
slotValue.min = temp;
|
||||
|
||||
// new value
|
||||
slots[slot] = new ResultRow(temp);
|
||||
byte[] bytes = new byte[length];
|
||||
System.arraycopy(key, 0, bytes, 0, length);
|
||||
keys[slot] = bytes;
|
||||
}
|
||||
else {
|
||||
// new value
|
||||
slots[slot] = new ResultRow(temp, start, nameLength);
|
||||
}
|
||||
}
|
||||
|
||||
static boolean unsafeEquals(final byte[] a, final byte[] b, final short length) {
|
||||
// byte by byte comparisons are slow, so do as big chunks as possible
|
||||
final int baseOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET;
|
||||
|
||||
short i = 0;
|
||||
// round down to nearest power of 8
|
||||
for (; i < (length & -8); i += 8) {
|
||||
if (UNSAFE.getLong(a, i + baseOffset) != UNSAFE.getLong(b, i + baseOffset)) {
|
||||
return false;
|
||||
// merge results with overall results
|
||||
ResultRow rr;
|
||||
String key;
|
||||
byte[] bytes;
|
||||
lock.lock();
|
||||
try {
|
||||
for (ResultRow resultRow : slots) {
|
||||
if (resultRow != null) {
|
||||
bytes = new byte[resultRow.keyLength];
|
||||
// copy the name bytes
|
||||
UNSAFE.copyMemory(null, resultRow.keyStart, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, resultRow.keyLength);
|
||||
key = new String(bytes, StandardCharsets.UTF_8);
|
||||
if ((rr = global.get(key)) != null) {
|
||||
rr.min = Math.min(rr.min, resultRow.min);
|
||||
rr.max = Math.max(rr.max, resultRow.max);
|
||||
rr.count += resultRow.count;
|
||||
rr.sum += resultRow.sum;
|
||||
}
|
||||
else {
|
||||
global.put(key, resultRow);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (i == length) {
|
||||
return true;
|
||||
}
|
||||
// leftover ints
|
||||
for (; i < (length - i & -4); i += 4) {
|
||||
if (UNSAFE.getInt(a, i + baseOffset) != UNSAFE.getInt(b, i + baseOffset)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (i == length) {
|
||||
return true;
|
||||
}
|
||||
// leftover shorts
|
||||
for (; i < (length - i & -2); i += 2) {
|
||||
if (UNSAFE.getShort(a, i + baseOffset) != UNSAFE.getShort(b, i + baseOffset)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (i == length) {
|
||||
return true;
|
||||
}
|
||||
// leftover bytes
|
||||
for (; i < (length - i); i++) {
|
||||
if (UNSAFE.getByte(a, i + baseOffset) != UNSAFE.getByte(b, i + baseOffset)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
finally {
|
||||
lock.unlock();
|
||||
}
|
||||
|
||||
// Get all pairs
|
||||
public List<MapEntry> getAll() {
|
||||
final List<MapEntry> result = new ArrayList<>(slots.length);
|
||||
for (int i = 0; i < slots.length; i++) {
|
||||
ResultRow slotValue = slots[i];
|
||||
if (slotValue != null) {
|
||||
result.add(new MapEntry(new String(keys[i], StandardCharsets.UTF_8), slotValue));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
static boolean unsafeEquals(final long a_address, final long b_address, final byte b_length) {
|
||||
// byte by byte comparisons are slow, so do as big chunks as possible
|
||||
byte i = 0;
|
||||
for (; i < (b_length & -8); i += 8) {
|
||||
if (UNSAFE.getLong(a_address + i) != UNSAFE.getLong(b_address + i)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (i == b_length)
|
||||
return true;
|
||||
return (UNSAFE.getLong(a_address + i) & masks[b_length - i]) == (UNSAFE.getLong(b_address + i) & masks[b_length - i]);
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws IOException, InterruptedException {
|
||||
int numThreads = 1;
|
||||
FileChannel channel = new RandomAccessFile(FILE, "r").getChannel();
|
||||
if (channel.size() > 64000) {
|
||||
numThreads = Runtime.getRuntime().availableProcessors();
|
||||
}
|
||||
Chunk[] chunks = getChunks(numThreads, channel);
|
||||
Thread[] threads = new Thread[chunks.length];
|
||||
for (int i = 0; i < chunks.length; i++) {
|
||||
int finalI = i;
|
||||
Thread thread = new Thread(() -> run(chunks[finalI]));
|
||||
thread.setPriority(Thread.MAX_PRIORITY);
|
||||
thread.start();
|
||||
threads[i] = thread;
|
||||
}
|
||||
for (Thread t : threads) {
|
||||
t.join();
|
||||
}
|
||||
System.out.println(global);
|
||||
channel.close();
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user