From e639e2a045371ab0be51404767a42f22f689cf2c Mon Sep 17 00:00:00 2001 From: Jamal Mulla Date: Wed, 31 Jan 2024 21:09:25 +0000 Subject: [PATCH] Second attempt with various improvements (#510) * Initial chunked impl * Bytes instead of chars * Improved number parsing * Custom hashmap * Graal and some tuning * Fix segmenting * Fix casing * Unsafe * Inlining hash calc * Improved loop * Cleanup * Speeding up equals * Simplifying hash * Replace concurrenthashmap with lock * Small changes * Script reorg * Native * Lots of inlining and improvements * Add back length check * Fixes * Small changes --------- Co-authored-by: Jamal Mulla --- calculate_average_JamalMulla.sh | 10 +- prepare_JamalMulla.sh | 8 +- .../onebrc/CalculateAverage_JamalMulla.java | 364 ++++++++---------- 3 files changed, 185 insertions(+), 197 deletions(-) diff --git a/calculate_average_JamalMulla.sh b/calculate_average_JamalMulla.sh index 228d56b..119263b 100755 --- a/calculate_average_JamalMulla.sh +++ b/calculate_average_JamalMulla.sh @@ -15,5 +15,11 @@ # limitations under the License. # -JAVA_OPTS="--enable-preview -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -XX:+UseTransparentHugePages" -java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_JamalMulla + + +if [ -f target/CalculateAverage_JamalMulla_image ]; then + target/CalculateAverage_JamalMulla_image +else + JAVA_OPTS="--enable-preview -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -XX:+UseTransparentHugePages -XX:-TieredCompilation" + java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_JamalMulla +fi \ No newline at end of file diff --git a/prepare_JamalMulla.sh b/prepare_JamalMulla.sh index ec0f35f..d950d43 100755 --- a/prepare_JamalMulla.sh +++ b/prepare_JamalMulla.sh @@ -16,4 +16,10 @@ # source "$HOME/.sdkman/bin/sdkman-init.sh" -sdk use java 21.0.1-graal 1>&2 \ No newline at end of file +sdk use java 21.0.2-graal 1>&2 + +# ./mvnw clean verify removes target/ and will re-trigger native image creation. +if [ ! -f target/CalculateAverage_JamalMulla_image ]; then + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview --strict-image-heap --link-at-build-time -R:MaxHeapSize=64m -da -dsa --no-fallback --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_JamalMulla" + native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_JamalMulla_image dev.morling.onebrc.CalculateAverage_JamalMulla +fi \ No newline at end of file diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_JamalMulla.java b/src/main/java/dev/morling/onebrc/CalculateAverage_JamalMulla.java index 7705885..7daf199 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_JamalMulla.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_JamalMulla.java @@ -21,21 +21,32 @@ import java.io.IOException; import java.io.RandomAccessFile; import java.lang.foreign.Arena; import java.lang.reflect.Field; -import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; -import java.util.*; +import java.util.Map; +import java.util.TreeMap; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; public class CalculateAverage_JamalMulla { - private static final Map global = new HashMap<>(); + private static final long ALL_SEMIS = 0x3B3B3B3B3B3B3B3BL; + private static final Map global = new TreeMap<>(); private static final String FILE = "./measurements.txt"; private static final Unsafe UNSAFE = initUnsafe(); private static final Lock lock = new ReentrantLock(); - private static final int FNV_32_INIT = 0x811c9dc5; - private static final int FNV_32_PRIME = 0x01000193; + private static final long FXSEED = 0x517cc1b727220a95L; + + private static final long[] masks = { + 0x0, + 0x00000000000000FFL, + 0x000000000000FFFFL, + 0x0000000000FFFFFFL, + 0x00000000FFFFFFFFL, + 0x000000FFFFFFFFFFL, + 0x0000FFFFFFFFFFFFL, + 0x00FFFFFFFFFFFFFFL + }; private static Unsafe initUnsafe() { try { @@ -53,12 +64,16 @@ public class CalculateAverage_JamalMulla { private int max; private long sum; private int count; + private final long keyStart; + private final byte keyLength; - private ResultRow(int v) { + private ResultRow(int v, final long keyStart, final byte keyLength) { this.min = v; this.max = v; this.sum = v; this.count = 1; + this.keyStart = keyStart; + this.keyLength = keyLength; } public String toString() { @@ -68,236 +83,197 @@ public class CalculateAverage_JamalMulla { private double round(double value) { return Math.round(value) / 10.0; } + } private record Chunk(Long start, Long length) { } - static List getChunks(int numThreads, FileChannel channel) throws IOException { + static Chunk[] getChunks(int numThreads, FileChannel channel) throws IOException { // get all chunk boundaries final long filebytes = channel.size(); final long roughChunkSize = filebytes / numThreads; - final List chunks = new ArrayList<>(numThreads); + final Chunk[] chunks = new Chunk[numThreads]; final long mappedAddress = channel.map(FileChannel.MapMode.READ_ONLY, 0, filebytes, Arena.global()).address(); long chunkStart = 0; long chunkLength = Math.min(filebytes - chunkStart - 1, roughChunkSize); + int i = 0; while (chunkStart < filebytes) { - // unlikely we need to read more than this many bytes to find the next newline - MappedByteBuffer mbb = channel.map(FileChannel.MapMode.READ_ONLY, chunkStart + chunkLength, - Math.min(Math.min(filebytes - chunkStart - chunkLength, chunkLength), 100)); - - while (mbb.get() != 0xA /* \n */) { + while (UNSAFE.getByte(mappedAddress + chunkStart + chunkLength) != 0xA /* \n */) { chunkLength++; } - chunks.add(new Chunk(mappedAddress + chunkStart, chunkLength + 1)); + chunks[i++] = new Chunk(mappedAddress + chunkStart, chunkLength + 1); // to skip the nl in the next chunk chunkStart += chunkLength + 1; chunkLength = Math.min(filebytes - chunkStart - 1, roughChunkSize); } + return chunks; } - private static class CalculateTask implements Runnable { + private static void run(Chunk chunk) { - private final SimplerHashMap results; - private final Chunk chunk; - - public CalculateTask(Chunk chunk) { - this.results = new SimplerHashMap(); - this.chunk = chunk; - } - - @Override - public void run() { - // no names bigger than this - final byte[] nameBytes = new byte[100]; - short nameIndex = 0; - int ot; - // fnv hash - int hash = FNV_32_INIT; - - long i = chunk.start; - final long cl = chunk.start + chunk.length; - while (i < cl) { - byte c; - while ((c = UNSAFE.getByte(i++)) != 0x3B /* semi-colon */) { - nameBytes[nameIndex++] = c; - hash ^= c; - hash *= FNV_32_PRIME; - } - - // temperature value follows - c = UNSAFE.getByte(i++); - // we know the val has to be between -99.9 and 99.8 - // always with a single fractional digit - // represented as a byte array of either 4 or 5 characters - if (c == 0x2D /* minus sign */) { - // could be either n.x or nn.x - if (UNSAFE.getByte(i + 3) == 0xA) { - ot = (UNSAFE.getByte(i++) - 48) * 10; // char 1 - } - else { - ot = (UNSAFE.getByte(i++) - 48) * 100; // char 1 - ot += (UNSAFE.getByte(i++) - 48) * 10; // char 2 - } - i++; // skip dot - ot += (UNSAFE.getByte(i++) - 48); // char 2 - ot = -ot; - } - else { - // could be either n.x or nn.x - if (UNSAFE.getByte(i + 2) == 0xA) { - ot = (c - 48) * 10; // char 1 - } - else { - ot = (c - 48) * 100; // char 1 - ot += (UNSAFE.getByte(i++) - 48) * 10; // char 2 - } - i++; // skip dot - ot += (UNSAFE.getByte(i++) - 48); // char 3 - } - - i++;// nl - hash &= 65535; - results.putOrMerge(nameBytes, nameIndex, hash, ot); - // reset - nameIndex = 0; - hash = 0x811c9dc5; - } - - // merge results with overall results - List all = results.getAll(); - lock.lock(); - try { - for (MapEntry me : all) { - ResultRow rr; - ResultRow lr = me.row; - if ((rr = global.get(me.key)) != null) { - rr.min = Math.min(rr.min, lr.min); - rr.max = Math.max(rr.max, lr.max); - rr.count += lr.count; - rr.sum += lr.sum; - } - else { - global.put(me.key, lr); - } - } - } - finally { - lock.unlock(); - } - } - } - - public static void main(String[] args) throws IOException, InterruptedException { - FileChannel channel = new RandomAccessFile(FILE, "r").getChannel(); - int numThreads = 1; - if (channel.size() > 64000) { - numThreads = Runtime.getRuntime().availableProcessors(); - } - List chunks = getChunks(numThreads, channel); - List threads = new ArrayList<>(); - for (Chunk chunk : chunks) { - Thread thread = new Thread(new CalculateTask(chunk)); - thread.setPriority(Thread.MAX_PRIORITY); - thread.start(); - threads.add(thread); - } - for (Thread t : threads) { - t.join(); - } - // create treemap just to sort - System.out.println(new TreeMap<>(global)); - } - - record MapEntry(String key, ResultRow row) { - } - - static class SimplerHashMap { // can't have more than 10000 unique keys but want to match max hash final int MAPSIZE = 65536; final ResultRow[] slots = new ResultRow[MAPSIZE]; - final byte[][] keys = new byte[MAPSIZE][]; - public void putOrMerge(final byte[] key, final short length, final int hash, final int temp) { - int slot = hash; - ResultRow slotValue; + byte nameLength; + int temp; + long hash; + + long i = chunk.start; + final long cl = chunk.start + chunk.length; + long word; + long hs; + long start; + byte c; + int slot; + long n; + ResultRow slotValue; + + while (i < cl) { + start = i; + hash = 0; + + word = UNSAFE.getLong(i); + + while (true) { + n = word ^ ALL_SEMIS; + hs = (n - 0x0101010101010101L) & (~n & 0x8080808080808080L); + if (hs != 0) + break; + hash = (hash ^ word) * FXSEED; + i += 8; + word = UNSAFE.getLong(i); + } + + i += Long.numberOfTrailingZeros(hs) >> 3; + + // hash of what's left ((hs >>> 7) - 1) masks off the bytes from word that are before the semicolon + hash = (hash ^ word & (hs >>> 7) - 1) * FXSEED; + nameLength = (byte) (i++ - start); + + // temperature value follows + c = UNSAFE.getByte(i++); + // we know the val has to be between -99.9 and 99.8 + // always with a single fractional digit + // represented as a byte array of either 4 or 5 characters + if (c != 0x2D /* minus sign */) { + // could be either n.x or nn.x + if (UNSAFE.getByte(i + 2) == 0xA) { + temp = (c - 48) * 10; // char 1 + } + else { + temp = (c - 48) * 100; // char 1 + temp += (UNSAFE.getByte(i++) - 48) * 10; // char 2 + } + temp += (UNSAFE.getByte(++i) - 48); // char 3 + } + else { + // could be either n.x or nn.x + if (UNSAFE.getByte(i + 3) == 0xA) { + temp = (UNSAFE.getByte(i) - 48) * 10; // char 1 + i += 2; + } + else { + temp = (UNSAFE.getByte(i) - 48) * 100; // char 1 + temp += (UNSAFE.getByte(i + 1) - 48) * 10; // char 2 + i += 3; + } + temp += (UNSAFE.getByte(i) - 48); // char 2 + temp = -temp; + } + i += 2; + + // xor folding + slot = (int) (hash ^ hash >> 32) & 65535; // Linear probe for open slot - while ((slotValue = slots[slot]) != null && (keys[slot].length != length || !unsafeEquals(keys[slot], key, length))) { - slot++; + while ((slotValue = slots[slot]) != null && (slotValue.keyLength != nameLength || !unsafeEquals(slotValue.keyStart, start, nameLength))) { + slot = (slot + 1) % MAPSIZE; } // existing if (slotValue != null) { - slotValue.min = Math.min(slotValue.min, temp); - slotValue.max = Math.max(slotValue.max, temp); slotValue.sum += temp; slotValue.count++; - return; - } + if (temp > slotValue.max) { + slotValue.max = temp; + continue; + } + if (temp < slotValue.min) + slotValue.min = temp; - // new value - slots[slot] = new ResultRow(temp); - byte[] bytes = new byte[length]; - System.arraycopy(key, 0, bytes, 0, length); - keys[slot] = bytes; + } + else { + // new value + slots[slot] = new ResultRow(temp, start, nameLength); + } } - static boolean unsafeEquals(final byte[] a, final byte[] b, final short length) { - // byte by byte comparisons are slow, so do as big chunks as possible - final int baseOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET; - - short i = 0; - // round down to nearest power of 8 - for (; i < (length & -8); i += 8) { - if (UNSAFE.getLong(a, i + baseOffset) != UNSAFE.getLong(b, i + baseOffset)) { - return false; + // merge results with overall results + ResultRow rr; + String key; + byte[] bytes; + lock.lock(); + try { + for (ResultRow resultRow : slots) { + if (resultRow != null) { + bytes = new byte[resultRow.keyLength]; + // copy the name bytes + UNSAFE.copyMemory(null, resultRow.keyStart, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, resultRow.keyLength); + key = new String(bytes, StandardCharsets.UTF_8); + if ((rr = global.get(key)) != null) { + rr.min = Math.min(rr.min, resultRow.min); + rr.max = Math.max(rr.max, resultRow.max); + rr.count += resultRow.count; + rr.sum += resultRow.sum; + } + else { + global.put(key, resultRow); + } } } - if (i == length) { - return true; - } - // leftover ints - for (; i < (length - i & -4); i += 4) { - if (UNSAFE.getInt(a, i + baseOffset) != UNSAFE.getInt(b, i + baseOffset)) { - return false; - } - } - if (i == length) { - return true; - } - // leftover shorts - for (; i < (length - i & -2); i += 2) { - if (UNSAFE.getShort(a, i + baseOffset) != UNSAFE.getShort(b, i + baseOffset)) { - return false; - } - } - if (i == length) { - return true; - } - // leftover bytes - for (; i < (length - i); i++) { - if (UNSAFE.getByte(a, i + baseOffset) != UNSAFE.getByte(b, i + baseOffset)) { - return false; - } - } - - return true; + } + finally { + lock.unlock(); } - // Get all pairs - public List getAll() { - final List result = new ArrayList<>(slots.length); - for (int i = 0; i < slots.length; i++) { - ResultRow slotValue = slots[i]; - if (slotValue != null) { - result.add(new MapEntry(new String(keys[i], StandardCharsets.UTF_8), slotValue)); - } - } - return result; - } } + static boolean unsafeEquals(final long a_address, final long b_address, final byte b_length) { + // byte by byte comparisons are slow, so do as big chunks as possible + byte i = 0; + for (; i < (b_length & -8); i += 8) { + if (UNSAFE.getLong(a_address + i) != UNSAFE.getLong(b_address + i)) { + return false; + } + } + if (i == b_length) + return true; + return (UNSAFE.getLong(a_address + i) & masks[b_length - i]) == (UNSAFE.getLong(b_address + i) & masks[b_length - i]); + } + + public static void main(String[] args) throws IOException, InterruptedException { + int numThreads = 1; + FileChannel channel = new RandomAccessFile(FILE, "r").getChannel(); + if (channel.size() > 64000) { + numThreads = Runtime.getRuntime().availableProcessors(); + } + Chunk[] chunks = getChunks(numThreads, channel); + Thread[] threads = new Thread[chunks.length]; + for (int i = 0; i < chunks.length; i++) { + int finalI = i; + Thread thread = new Thread(() -> run(chunks[finalI])); + thread.setPriority(Thread.MAX_PRIORITY); + thread.start(); + threads[i] = thread; + } + for (Thread t : threads) { + t.join(); + } + System.out.println(global); + channel.close(); + } }