diff --git a/calculate_average_jerrinot.sh b/calculate_average_jerrinot.sh index 1bbf680..8de06c3 100755 --- a/calculate_average_jerrinot.sh +++ b/calculate_average_jerrinot.sh @@ -17,5 +17,5 @@ # -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_mtopolnik::recordMeasurementAndAdvanceCursor" # -XX:InlineSmallCode=10000 -XX:-TieredCompilation -XX:CICompilerCount=2 -XX:CompileThreshold=1000\ -java --enable-preview \ +java -XX:+UseParallelGC --enable-preview \ --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jerrinot diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java index 6fb89bb..5373cb0 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java @@ -22,15 +22,24 @@ import java.io.RandomAccessFile; import java.lang.foreign.Arena; import java.lang.reflect.Field; import java.nio.channels.FileChannel.MapMode; -import java.util.Map; -import java.util.TreeMap; +import java.util.*; +/** + * I figured out it would be very hard to win the main competition of the One Billion Rows Challenge. + * but I think this code has a good chance to win a special prize for the Ugliest Solution ever! :) + * + * Anyway, if you can make sense out of not exactly idiomatic Java code, and you enjoy pushing performance limits + * then QuestDB - the fastest open-source time-series database - is hiring: https://questdb.io/careers/core-database-engineer/ + * + */ public class CalculateAverage_jerrinot { private static final Unsafe UNSAFE = unsafe(); private static final String MEASUREMENTS_TXT = "measurements.txt"; // todo: with hyper-threading enable we would be better of with availableProcessors / 2; // todo: validate the testing env. params. private static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors(); + // private static final int THREAD_COUNT = 4; + private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL; private static Unsafe unsafe() { @@ -72,7 +81,7 @@ public class CalculateAverage_jerrinot { Processor[] processors = new Processor[THREAD_COUNT]; Thread[] threads = new Thread[THREAD_COUNT]; - for (int i = 0; i < THREAD_COUNT; i++) { + for (int i = 0; i < THREAD_COUNT - 1; i++) { long startA = chunkStartOffsets[i * chunkPerThread]; long endA = chunkStartOffsets[i * chunkPerThread + 1]; long startB = chunkStartOffsets[i * chunkPerThread + 1]; @@ -89,8 +98,22 @@ public class CalculateAverage_jerrinot { thread.start(); } + int ownIndex = THREAD_COUNT - 1; + long startA = chunkStartOffsets[ownIndex * chunkPerThread]; + long endA = chunkStartOffsets[ownIndex * chunkPerThread + 1]; + long startB = chunkStartOffsets[ownIndex * chunkPerThread + 1]; + long endB = chunkStartOffsets[ownIndex * chunkPerThread + 2]; + long startC = chunkStartOffsets[ownIndex * chunkPerThread + 2]; + long endC = chunkStartOffsets[ownIndex * chunkPerThread + 3]; + long startD = chunkStartOffsets[ownIndex * chunkPerThread + 3]; + long endD = chunkStartOffsets[ownIndex * chunkPerThread + 4]; + Processor processor = new Processor(startA, endA, startB, endB, startC, endC, startD, endD); + processor.run(); + var accumulator = new TreeMap(); - for (int i = 0; i < THREAD_COUNT; i++) { + processor.accumulateStatus(accumulator); + + for (int i = 0; i < THREAD_COUNT - 1; i++) { Thread t = threads[i]; t.join(); processors[i].accumulateStatus(accumulator); @@ -131,7 +154,7 @@ public class CalculateAverage_jerrinot { private static class Processor implements Runnable { private static final int MAP_SLOT_COUNT = ceilPow2(10000); - private static final int STATION_MAX_NAME_BYTES = 104; + private static final int STATION_MAX_NAME_BYTES = 120; private static final long COUNT_OFFSET = 0; private static final long MIN_OFFSET = 4; @@ -162,23 +185,16 @@ public class CalculateAverage_jerrinot { private long endC; private long cursorD; private long endD; - private long maskA; - private long maskB; - private long maskC; - private long maskD; + + // private long maxClusterLen; // credit: merykitty - private long parseAndStoreTemperature(long startCursor, long baseEntryPtr) { - long word = UNSAFE.getLong(startCursor); - final long negateda = ~word; - final int dotPos = Long.numberOfTrailingZeros(negateda & 0x10101000); - final long signed = (negateda << 59) >> 63; - final long removeSignMask = ~(signed & 0xFF); - final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L; - final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; - final int temperature = (int) ((absValue ^ signed) - signed); - + private long parseAndStoreTemperature(long startCursor, long baseEntryPtr, long word) { + // long word = UNSAFE.getLong(startCursor); long countPtr = baseEntryPtr + COUNT_OFFSET; + int cnt = UNSAFE.getInt(countPtr); + UNSAFE.putInt(countPtr, cnt + 1); + long minPtr = baseEntryPtr + MIN_OFFSET; long maxPtr = baseEntryPtr + MAX_OFFSET; long sumPtr = baseEntryPtr + SUM_OFFSET; @@ -186,16 +202,23 @@ public class CalculateAverage_jerrinot { int min = UNSAFE.getInt(minPtr); int max = UNSAFE.getInt(maxPtr); long sum = UNSAFE.getLong(sumPtr); - // try if min/max intrinsics are paying off - // maybe braching is better? the branch is becoming more predictable with - // each new sample. - max = Math.max(max, temperature); - min = Math.min(min, temperature); + + final long negateda = ~word; + final int dotPos = Long.numberOfTrailingZeros(negateda & 0x10101000); + final long signed = (negateda << 59) >> 63; + final long removeSignMask = ~(signed & 0xFF); + final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L; + final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; + final int temperature = (int) ((absValue ^ signed) - signed); sum += temperature; - UNSAFE.putInt(countPtr, UNSAFE.getInt(countPtr) + 1); - UNSAFE.putInt(minPtr, min); - UNSAFE.putInt(maxPtr, max); UNSAFE.putLong(sumPtr, sum); + + if (temperature > max) { + UNSAFE.putInt(maxPtr, temperature); + } + if (temperature < min) { + UNSAFE.putInt(minPtr, temperature); + } return startCursor + (dotPos / 8) + 3; } @@ -227,13 +250,13 @@ public class CalculateAverage_jerrinot { int count = UNSAFE.getInt(baseAddress + COUNT_OFFSET); long sum = UNSAFE.getLong(baseAddress + SUM_OFFSET); - // todo: lambdas bootstrap probably cost us - accumulator.compute(name, (_, v) -> { - if (v == null) { - return new StationStats(min, max, count, sum); - } - return new StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum); - }); + var v = accumulator.get(name); + if (v == null) { + accumulator.put(name, new StationStats(min, max, count, sum)); + } + else { + accumulator.put(name, new StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum)); + } } } @@ -260,11 +283,22 @@ public class CalculateAverage_jerrinot { private void doTail() { // todo: we would be probably better of without all that code dup. ("compilers hates him!") // System.out.println("done ILP"); + doOne(cursorA, endA); + // System.out.println("done A"); + doOne(cursorB, endB); + // System.out.println("done B"); + doOne(cursorC, endC); + // System.out.println("done C"); + doOne(cursorD, endD); + // System.out.println("done D"); + } + + private void doOne(long cursorA, long endA) { while (cursorA < endA) { long startA = cursorA; long delimiterWordA = UNSAFE.getLong(cursorA); long hashA = 0; - maskA = getDelimiterMask(delimiterWordA); + long maskA = getDelimiterMask(delimiterWordA); while (maskA == 0) { hashA ^= delimiterWordA; cursorA += 8; @@ -273,81 +307,15 @@ public class CalculateAverage_jerrinot { } final int delimiterByteA = Long.numberOfTrailingZeros(maskA); final long semicolonA = cursorA + (delimiterByteA >> 3); - final long maskedWordA = delimiterWordA & ((maskA >>> 7) - 1); + final long maskedWordA = delimiterWordA & ((maskA - 1) ^ maskA) >>> 8; hashA ^= maskedWordA; int intHashA = (int) (hashA ^ (hashA >> 32)); intHashA = intHashA ^ (intHashA >> 17); long baseEntryPtrA = getOrCreateEntryBaseOffset(semicolonA, startA, intHashA, maskedWordA); - cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA); + long temperatureWordA = UNSAFE.getLong(semicolonA + 1); + cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA, temperatureWordA); } - // System.out.println("done A"); - while (cursorB < endB) { - long startB = cursorB; - long delimiterWordB = UNSAFE.getLong(cursorB); - long hashB = 0; - maskB = getDelimiterMask(delimiterWordB); - while (maskB == 0) { - hashB ^= delimiterWordB; - cursorB += 8; - delimiterWordB = UNSAFE.getLong(cursorB); - maskB = getDelimiterMask(delimiterWordB); - } - final int delimiterByteB = Long.numberOfTrailingZeros(maskB); - final long semicolonB = cursorB + (delimiterByteB >> 3); - final long maskedWordB = delimiterWordB & ((maskB >>> 7) - 1); - hashB ^= maskedWordB; - int intHashB = (int) (hashB ^ (hashB >> 32)); - intHashB = intHashB ^ (intHashB >> 17); - - long baseEntryPtrB = getOrCreateEntryBaseOffset(semicolonB, startB, intHashB, maskedWordB); - cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB); - } - // System.out.println("done B"); - while (cursorC < endC) { - long startC = cursorC; - long delimiterWordC = UNSAFE.getLong(cursorC); - long hashC = 0; - maskC = getDelimiterMask(delimiterWordC); - while (maskC == 0) { - hashC ^= delimiterWordC; - cursorC += 8; - delimiterWordC = UNSAFE.getLong(cursorC); - maskC = getDelimiterMask(delimiterWordC); - } - final int delimiterByteC = Long.numberOfTrailingZeros(maskC); - final long semicolonC = cursorC + (delimiterByteC >> 3); - final long maskedWordC = delimiterWordC & ((maskC >>> 7) - 1); - hashC ^= maskedWordC; - int intHashC = (int) (hashC ^ (hashC >> 32)); - intHashC = intHashC ^ (intHashC >> 17); - - long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, intHashC, maskedWordC); - cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC); - } - // System.out.println("done C"); - while (cursorD < endD) { - long startD = cursorD; - long delimiterWordD = UNSAFE.getLong(cursorD); - long hashD = 0; - maskD = getDelimiterMask(delimiterWordD); - while (maskD == 0) { - hashD ^= delimiterWordD; - cursorD += 8; - delimiterWordD = UNSAFE.getLong(cursorD); - maskD = getDelimiterMask(delimiterWordD); - } - final int delimiterByteD = Long.numberOfTrailingZeros(maskD); - final long semicolonD = cursorD + (delimiterByteD >> 3); - final long maskedWordD = delimiterWordD & ((maskD >>> 7) - 1); - hashD ^= maskedWordD; - int intHashD = (int) (hashD ^ (hashD >> 32)); - intHashD = intHashD ^ (intHashD >> 17); - - long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, intHashD, maskedWordD); - cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD); - } - // System.out.println("done D"); } @Override @@ -359,10 +327,14 @@ public class CalculateAverage_jerrinot { long startC = cursorC; long startD = cursorD; - long delimiterWordA = UNSAFE.getLong(cursorA); - long delimiterWordB = UNSAFE.getLong(cursorB); - long delimiterWordC = UNSAFE.getLong(cursorC); - long delimiterWordD = UNSAFE.getLong(cursorD); + long currentWordA = UNSAFE.getLong(startA); + // long delimiterWordA2 = UNSAFE.getLong(startA + 8); + long currentWordB = UNSAFE.getLong(startB); + // long delimiterWordB2 = UNSAFE.getLong(startB + 8); + long currentWordC = UNSAFE.getLong(startC); + // long delimiterWordCa = UNSAFE.getLong(startC + 8); + long currentWordD = UNSAFE.getLong(startD); + // long delimiterWordD2 = UNSAFE.getLong(startD + 8); long hashA = 0; long hashB = 0; @@ -370,58 +342,62 @@ public class CalculateAverage_jerrinot { long hashD = 0; // credits for the hashing idea: royvanrijn - maskA = getDelimiterMask(delimiterWordA); + long maskA = getDelimiterMask(currentWordA); while (maskA == 0) { - hashA ^= delimiterWordA; + hashA ^= currentWordA; cursorA += 8; - delimiterWordA = UNSAFE.getLong(cursorA); - maskA = getDelimiterMask(delimiterWordA); + currentWordA = UNSAFE.getLong(cursorA); + maskA = getDelimiterMask(currentWordA); } final int delimiterByteA = Long.numberOfTrailingZeros(maskA); final long semicolonA = cursorA + (delimiterByteA >> 3); - final long maskedWordA = delimiterWordA & ((maskA >>> 7) - 1); + long temperatureWordA = UNSAFE.getLong(semicolonA + 1); + final long maskedWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8; hashA ^= maskedWordA; int intHashA = (int) (hashA ^ (hashA >> 32)); intHashA = intHashA ^ (intHashA >> 17); - maskB = getDelimiterMask(delimiterWordB); + long maskB = getDelimiterMask(currentWordB); while (maskB == 0) { - hashB ^= delimiterWordB; + hashB ^= currentWordB; cursorB += 8; - delimiterWordB = UNSAFE.getLong(cursorB); - maskB = getDelimiterMask(delimiterWordB); + currentWordB = UNSAFE.getLong(cursorB); + maskB = getDelimiterMask(currentWordB); } final int delimiterByteB = Long.numberOfTrailingZeros(maskB); final long semicolonB = cursorB + (delimiterByteB >> 3); - final long maskedWordB = delimiterWordB & ((maskB >>> 7) - 1); + long temperatureWordB = UNSAFE.getLong(semicolonB + 1); + final long maskedWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8; hashB ^= maskedWordB; int intHashB = (int) (hashB ^ (hashB >> 32)); intHashB = intHashB ^ (intHashB >> 17); - maskC = getDelimiterMask(delimiterWordC); + long maskC = getDelimiterMask(currentWordC); while (maskC == 0) { - hashC ^= delimiterWordC; + hashC ^= currentWordC; cursorC += 8; - delimiterWordC = UNSAFE.getLong(cursorC); - maskC = getDelimiterMask(delimiterWordC); + currentWordC = UNSAFE.getLong(cursorC); + maskC = getDelimiterMask(currentWordC); } final int delimiterByteC = Long.numberOfTrailingZeros(maskC); final long semicolonC = cursorC + (delimiterByteC >> 3); - final long maskedWordC = delimiterWordC & ((maskC >>> 7) - 1); + long temperatureWordC = UNSAFE.getLong(semicolonC + 1); + final long maskedWordC = currentWordC & ((maskC - 1) ^ maskC) >>> 8; hashC ^= maskedWordC; int intHashC = (int) (hashC ^ (hashC >> 32)); intHashC = intHashC ^ (intHashC >> 17); - maskD = getDelimiterMask(delimiterWordD); + long maskD = getDelimiterMask(currentWordD); while (maskD == 0) { - hashD ^= delimiterWordD; + hashD ^= currentWordD; cursorD += 8; - delimiterWordD = UNSAFE.getLong(cursorD); - maskD = getDelimiterMask(delimiterWordD); + currentWordD = UNSAFE.getLong(cursorD); + maskD = getDelimiterMask(currentWordD); } final int delimiterByteD = Long.numberOfTrailingZeros(maskD); final long semicolonD = cursorD + (delimiterByteD >> 3); - final long maskedWordD = delimiterWordD & ((maskD >>> 7) - 1); + long temperatureWordD = UNSAFE.getLong(semicolonD + 1); + final long maskedWordD = currentWordD & ((maskD - 1) ^ maskD) >>> 8; hashD ^= maskedWordD; int intHashD = (int) (hashD ^ (hashD >> 32)); intHashD = intHashD ^ (intHashD >> 17); @@ -431,52 +407,75 @@ public class CalculateAverage_jerrinot { long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, intHashC, maskedWordC); long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, intHashD, maskedWordD); - cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA); - cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB); - cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC); - cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD); + cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA, temperatureWordA); + cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB, temperatureWordB); + cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC, temperatureWordC); + cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD, temperatureWordD); } doTail(); } private long getOrCreateEntryBaseOffset(long semicolonA, long startA, int intHashA, long maskedWordA) { - int lenA = (int) (semicolonA - startA); + // hashSet.add(intHashA); + long lenLong = semicolonA - startA; + int lenA = (int) lenLong; + + // assert lenA != 0; + // byte[] nameArr = new byte[lenA]; + // for (int i = 0; i < lenA; i++) { + // nameArr[i] = UNSAFE.getByte(startA + i); + // } + // String nameStr = new String(nameArr); + // Integer oldHash = nameToHash.put(nameStr, intHashA); + // assert oldHash == null || oldHash == intHashA : "name: " + nameStr + ", old hash = " + oldHash + ", new hash = " + intHashA; + long mapIndexA = intHashA & MAP_MASK; + // long clusterLen = 0; for (;;) { long basePtr = mapIndexA * MAP_ENTRY_SIZE_BYTES + map; long lenPtr = basePtr + LEN_OFFSET; int len = UNSAFE.getInt(lenPtr); - if (len == 0) { + if (len == lenA) { + if (nameMatch(startA, maskedWordA, basePtr, lenLong)) { + // if (clusterLen > maxClusterLen) { + // maxClusterLen = clusterLen; + // System.out.println("max cluster len: " + clusterLen); + // } + return basePtr; + } + } + else if (len == 0) { // todo: uncommon branch maybe? // empty slot UNSAFE.copyMemory(semicolonA - lenA, basePtr + NAME_OFFSET, lenA); UNSAFE.putInt(lenPtr, lenA); + // todo: this could be a single putLong() UNSAFE.putInt(basePtr + MAX_OFFSET, Integer.MIN_VALUE); UNSAFE.putInt(basePtr + MIN_OFFSET, Integer.MAX_VALUE); return basePtr; } - if (len == lenA) { - boolean match = true; - long namePtr = basePtr + NAME_OFFSET; - int fullLen = (len >> 3) << 3; - long offset; - // todo: this is worth exploring further. - // @mtopolnik has an interesting algo with 2 unconditioned long loads: this is sufficient - // for majority of names. so we would be left with just a single branch which is almost never taken? - for (offset = 0; offset < fullLen; offset += 8) { - match &= (UNSAFE.getLong(startA + offset) == UNSAFE.getLong(namePtr + offset)); - } - - long maskedWordInMap = UNSAFE.getLong(namePtr + offset); - match &= (maskedWordInMap == maskedWordA); - - if (match) { - return basePtr; - } - } mapIndexA = ++mapIndexA & MAP_MASK; + // clusterLen++; } } + + private static boolean nameMatch(long startA, long maskedWordA, long basePtr, long len) { + long namePtr = basePtr + NAME_OFFSET; + long fullLen = len & ~7L; + long offset; + + // todo: this is worth exploring further. + // @mtopolnik has an interesting algo with 2 unconditioned long loads: this is sufficient + // for majority of names. so we would be left with just a single branch which is almost never taken? + for (offset = 0; offset < fullLen; offset += 8) { + if (UNSAFE.getLong(startA + offset) != UNSAFE.getLong(namePtr + offset)) { + return false; + } + } + + long maskedWordInMap = UNSAFE.getLong(namePtr + fullLen); + return (maskedWordInMap == maskedWordA); + } } }