jerrinot's improvement (#607)

* some random changes with minimal, if any, effect

* use munmap() trick
credit: thomaswue

* some smaller tweaks

* use native image
This commit is contained in:
Jaromir Hamala 2024-01-28 11:34:28 +01:00 committed by GitHub
parent a6cd83fc98
commit d9ab36a241
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 175 additions and 106 deletions

View File

@ -17,5 +17,11 @@
# -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_mtopolnik::recordMeasurementAndAdvanceCursor" # -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_mtopolnik::recordMeasurementAndAdvanceCursor"
# -XX:InlineSmallCode=10000 -XX:-TieredCompilation -XX:CICompilerCount=2 -XX:CompileThreshold=1000\ # -XX:InlineSmallCode=10000 -XX:-TieredCompilation -XX:CICompilerCount=2 -XX:CompileThreshold=1000\
java -XX:+UseParallelGC --enable-preview \ if [ -f target/CalculateAverage_jerrinot_image ]; then
--class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jerrinot echo "Picking up existing native image 'target/CalculateAverage_jerrinot_image', delete the file to select JVM mode." 1>&2
target/CalculateAverage_jerrinot_image
else
JAVA_OPTS="--enable-preview"
echo "Choosing to run the app in JVM mode as no native image was found, use prepare_jerrinot.sh to generate." 1>&2
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jerrinot
fi

View File

@ -16,4 +16,11 @@
# #
source "$HOME/.sdkman/bin/sdkman-init.sh" source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.1-graal 1>&2 sdk use java 21.0.2-graal 1>&2
# The native image is written under target/, which `./mvnw clean verify` deletes,
# so a missing image file is the signal to build it (again).
if ! [ -f target/CalculateAverage_jerrinot_image ]; then
    # Kept as one plain string and expanded unquoted below so that each option
    # becomes a separate argument to native-image.
    NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_jerrinot"
    # For IdealGraphVisualizer graph dumping, add:
    #   -H:MethodFilter=CalculateAverage_jerrinot.* -H:Dump=:2 -H:PrintGraph=Network
    native-image $NATIVE_IMAGE_OPTS \
        -cp target/average-1.0.0-SNAPSHOT.jar \
        -o target/CalculateAverage_jerrinot_image \
        dev.morling.onebrc.CalculateAverage_jerrinot
fi

View File

@ -18,6 +18,7 @@ package dev.morling.onebrc;
import sun.misc.Unsafe; import sun.misc.Unsafe;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile; import java.io.RandomAccessFile;
import java.lang.foreign.Arena; import java.lang.foreign.Arena;
import java.lang.reflect.Field; import java.lang.reflect.Field;
@ -54,9 +55,29 @@ public class CalculateAverage_jerrinot {
} }
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
// credits for spawning new workers: thomaswue
if (args.length == 0 || !("--worker".equals(args[0]))) {
spawnWorker();
return;
}
calculate(); calculate();
} }
/**
 * Re-launches the current process as a child with {@code --worker} appended to its
 * command line and copies the child's standard output to this process's
 * {@code System.out}.
 *
 * <p>The child inherits stdin and stderr; only its stdout is piped back here. The
 * method does not wait for the child to exit: it returns as soon as the child's
 * stdout reaches end-of-stream.
 *
 * @throws IOException if the child process cannot be started or the copy fails
 */
private static void spawnWorker() throws IOException {
    ProcessHandle.Info self = ProcessHandle.current().info();

    // Rebuild our own command line: executable first, then the original arguments.
    ArrayList<String> commandLine = new ArrayList<>();
    String executable = self.command().orElse(null);
    if (executable != null) {
        commandLine.add(executable);
    }
    String[] originalArguments = self.arguments().orElse(null);
    if (originalArguments != null) {
        commandLine.addAll(Arrays.asList(originalArguments));
    }
    // Marker checked by main() so the child does the work instead of spawning again.
    commandLine.add("--worker");

    ProcessBuilder builder = new ProcessBuilder();
    builder.command(commandLine);
    // Inherit all three standard streams, then override stdout with a pipe we can read.
    builder.inheritIO();
    builder.redirectOutput(ProcessBuilder.Redirect.PIPE);

    Process worker = builder.start();
    worker.getInputStream().transferTo(System.out);
}
static void calculate() throws Exception { static void calculate() throws Exception {
final File file = new File(MEASUREMENTS_TXT); final File file = new File(MEASUREMENTS_TXT);
final long length = file.length(); final long length = file.length();
@ -140,6 +161,7 @@ public class CalculateAverage_jerrinot {
} }
sb.append('}'); sb.append('}');
System.out.println(sb); System.out.println(sb);
System.out.close();
} }
public static int ceilPow2(int i) { public static int ceilPow2(int i) {
@ -187,7 +209,7 @@ public class CalculateAverage_jerrinot {
private static final int SLOW_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * SLOW_MAP_ENTRY_SIZE_BYTES; private static final int SLOW_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * SLOW_MAP_ENTRY_SIZE_BYTES;
private static final int FAST_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * FAST_MAP_ENTRY_SIZE_BYTES; private static final int FAST_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * FAST_MAP_ENTRY_SIZE_BYTES;
private static final int SLOW_MAP_MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES; private static final int SLOW_MAP_MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES;
private static final long MAP_MASK = MAPS_SLOT_COUNT - 1; private static final int MAP_MASK = MAPS_SLOT_COUNT - 1;
private long slowMap; private long slowMap;
private long slowMapNamesPtr; private long slowMapNamesPtr;
@ -281,9 +303,9 @@ public class CalculateAverage_jerrinot {
doOne(cursorC, endC); doOne(cursorC, endC);
transferToHeap(); transferToHeap();
UNSAFE.freeMemory(fastMap); // UNSAFE.freeMemory(fastMap);
UNSAFE.freeMemory(slowMap); // UNSAFE.freeMemory(slowMap);
UNSAFE.freeMemory(slowMapNamesLo); // UNSAFE.freeMemory(slowMapNamesLo);
} }
private void transferToHeap() { private void transferToHeap() {
@ -339,11 +361,11 @@ public class CalculateAverage_jerrinot {
long mask = getDelimiterMask(currentWord); long mask = getDelimiterMask(currentWord);
long firstWordMask = ((mask - 1) ^ mask) >>> 8; long firstWordMask = ((mask - 1) ^ mask) >>> 8;
final long isMaskZeroA = ((mask | -mask) >>> 63) ^ 1; final long isMaskZeroA = ((mask | -mask) >>> 63) ^ 1;
long ext = -isMaskZeroA & 0xFF00_0000_0000_0000L; long ext = -isMaskZeroA;
firstWordMask |= ext; firstWordMask |= ext;
long maskedFirstWord = currentWord & firstWordMask; long maskedFirstWord = currentWord & firstWordMask;
long hash = hash(maskedFirstWord); int hash = hash(maskedFirstWord);
while (mask == 0) { while (mask == 0) {
cursor += 8; cursor += 8;
currentWord = UNSAFE.getLong(cursor); currentWord = UNSAFE.getLong(cursor);
@ -353,22 +375,22 @@ public class CalculateAverage_jerrinot {
final long semicolon = cursor + (delimiterByte >> 3); final long semicolon = cursor + (delimiterByte >> 3);
final long maskedWord = currentWord & ((mask - 1) ^ mask) >>> 8; final long maskedWord = currentWord & ((mask - 1) ^ mask) >>> 8;
long len = semicolon - start; int len = (int) (semicolon - start);
long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, (int) hash, maskedWord); long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, hash, maskedWord);
long temperatureWord = UNSAFE.getLong(semicolon + 1); long temperatureWord = UNSAFE.getLong(semicolon + 1);
cursor = parseAndStoreTemperature(semicolon + 1, baseEntryPtr, temperatureWord); cursor = parseAndStoreTemperature(semicolon + 1, baseEntryPtr, temperatureWord);
} }
} }
private static long hash(long word1) { private static int hash(long word) {
// credit: mtopolnik // credit: mtopolnik
long seed = 0x51_7c_c1_b7_27_22_0a_95L; long seed = 0x51_7c_c1_b7_27_22_0a_95L;
int rotDist = 17; int rotDist = 17;
//
long hash = word1; long hash = word;
hash *= seed; hash *= seed;
hash = Long.rotateLeft(hash, rotDist); hash = Long.rotateLeft(hash, rotDist);
return hash; return (int) hash;
} }
@Override @Override
@ -382,69 +404,87 @@ public class CalculateAverage_jerrinot {
UNSAFE.setMemory(slowMapNamesPtr, SLOW_MAP_MAP_NAMES_BYTES, (byte) 0); UNSAFE.setMemory(slowMapNamesPtr, SLOW_MAP_MAP_NAMES_BYTES, (byte) 0);
while (cursorA < endA && cursorB < endB && cursorC < endC) { while (cursorA < endA && cursorB < endB && cursorC < endC) {
long currentWordA = UNSAFE.getLong(cursorA);
long currentWordB = UNSAFE.getLong(cursorB);
long currentWordC = UNSAFE.getLong(cursorC);
long startA = cursorA; long startA = cursorA;
long startB = cursorB; long startB = cursorB;
long startC = cursorC; long startC = cursorC;
long currentWordA = UNSAFE.getLong(startA);
long currentWordB = UNSAFE.getLong(startB);
long currentWordC = UNSAFE.getLong(startC);
long maskA = getDelimiterMask(currentWordA); long maskA = getDelimiterMask(currentWordA);
long maskB = getDelimiterMask(currentWordB); long maskB = getDelimiterMask(currentWordB);
long maskC = getDelimiterMask(currentWordC); long maskC = getDelimiterMask(currentWordC);
long firstWordMaskA = (maskA ^ (maskA - 1)) >>> 8; long maskComplementA = -maskA;
long firstWordMaskB = (maskB ^ (maskB - 1)) >>> 8; long maskComplementB = -maskB;
long firstWordMaskC = (maskC ^ (maskC - 1)) >>> 8; long maskComplementC = -maskC;
final long isMaskZeroA = ((maskA | -maskA) >>> 63) ^ 1; long maskWithDelimiterA = (maskA ^ (maskA - 1));
final long isMaskZeroB = ((maskB | -maskB) >>> 63) ^ 1; long maskWithDelimiterB = (maskB ^ (maskB - 1));
final long isMaskZeroC = ((maskC | -maskC) >>> 63) ^ 1; long maskWithDelimiterC = (maskC ^ (maskC - 1));
long extA = -isMaskZeroA & 0xFF00_0000_0000_0000L; long isMaskZeroA = (((maskA | maskComplementA) >>> 63) ^ 1);
long extB = -isMaskZeroB & 0xFF00_0000_0000_0000L; long isMaskZeroB = (((maskB | maskComplementB) >>> 63) ^ 1);
long extC = -isMaskZeroC & 0xFF00_0000_0000_0000L; long isMaskZeroC = (((maskC | maskComplementC) >>> 63) ^ 1);
firstWordMaskA |= extA; cursorA += isMaskZeroA << 3;
firstWordMaskB |= extB; cursorB += isMaskZeroB << 3;
firstWordMaskC |= extC; cursorC += isMaskZeroC << 3;
long maskedFirstWordA = currentWordA & firstWordMaskA; long nextWordA = UNSAFE.getLong(cursorA);
long maskedFirstWordB = currentWordB & firstWordMaskB; long nextWordB = UNSAFE.getLong(cursorB);
long maskedFirstWordC = currentWordC & firstWordMaskC; long nextWordC = UNSAFE.getLong(cursorC);
// assertMasks(isMaskZeroA, maskA); long firstWordMaskA = maskWithDelimiterA >>> 8;
long firstWordMaskB = maskWithDelimiterB >>> 8;
long firstWordMaskC = maskWithDelimiterC >>> 8;
long hashA = hash(maskedFirstWordA); long nextMaskA = getDelimiterMask(nextWordA);
long hashB = hash(maskedFirstWordB); long nextMaskB = getDelimiterMask(nextWordB);
long hashC = hash(maskedFirstWordC); long nextMaskC = getDelimiterMask(nextWordC);
cursorA += isMaskZeroA * 8; boolean slowA = nextMaskA == 0;
cursorB += isMaskZeroB * 8; boolean slowB = nextMaskB == 0;
cursorC += isMaskZeroC * 8; boolean slowC = nextMaskC == 0;
boolean slowSome = (slowA || slowB || slowC);
currentWordA = UNSAFE.getLong(cursorA); long extA = -isMaskZeroA;
currentWordB = UNSAFE.getLong(cursorB); long extB = -isMaskZeroB;
currentWordC = UNSAFE.getLong(cursorC); long extC = -isMaskZeroC;
maskA = getDelimiterMask(currentWordA); long maskedFirstWordA = (extA | firstWordMaskA) & currentWordA;
while (maskA == 0) { long maskedFirstWordB = (extB | firstWordMaskB) & currentWordB;
cursorA += 8; long maskedFirstWordC = (extC | firstWordMaskC) & currentWordC;
currentWordA = UNSAFE.getLong(cursorA);
maskA = getDelimiterMask(currentWordA); int hashA = hash(maskedFirstWordA);
} int hashB = hash(maskedFirstWordB);
maskB = getDelimiterMask(currentWordB); int hashC = hash(maskedFirstWordC);
while (maskB == 0) {
cursorB += 8; currentWordA = nextWordA;
currentWordB = UNSAFE.getLong(cursorB); currentWordB = nextWordB;
maskB = getDelimiterMask(currentWordB); currentWordC = nextWordC;
}
maskC = getDelimiterMask(currentWordC); maskA = nextMaskA;
while (maskC == 0) { maskB = nextMaskB;
cursorC += 8; maskC = nextMaskC;
currentWordC = UNSAFE.getLong(cursorC); if (slowSome) {
maskC = getDelimiterMask(currentWordC); while (maskA == 0) {
cursorA += 8;
currentWordA = UNSAFE.getLong(cursorA);
maskA = getDelimiterMask(currentWordA);
}
while (maskB == 0) {
cursorB += 8;
currentWordB = UNSAFE.getLong(cursorB);
maskB = getDelimiterMask(currentWordB);
}
while (maskC == 0) {
cursorC += 8;
currentWordC = UNSAFE.getLong(cursorC);
maskC = getDelimiterMask(currentWordC);
}
} }
final int delimiterByteA = Long.numberOfTrailingZeros(maskA); final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
@ -458,40 +498,57 @@ public class CalculateAverage_jerrinot {
long digitStartA = semicolonA + 1; long digitStartA = semicolonA + 1;
long digitStartB = semicolonB + 1; long digitStartB = semicolonB + 1;
long digitStartC = semicolonC + 1; long digitStartC = semicolonC + 1;
long temperatureWordA = UNSAFE.getLong(digitStartA); long temperatureWordA = UNSAFE.getLong(digitStartA);
long temperatureWordB = UNSAFE.getLong(digitStartB); long temperatureWordB = UNSAFE.getLong(digitStartB);
long temperatureWordC = UNSAFE.getLong(digitStartC); long temperatureWordC = UNSAFE.getLong(digitStartC);
final long maskedWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8; long lastWordMaskA = ((maskA - 1) ^ maskA) >>> 8;
final long maskedWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8; long lastWordMaskB = ((maskB - 1) ^ maskB) >>> 8;
final long maskedWordC = currentWordC & ((maskC - 1) ^ maskC) >>> 8; long lastWordMaskC = ((maskC - 1) ^ maskC) >>> 8;
long lenA = semicolonA - startA; final long maskedLastWordA = currentWordA & lastWordMaskA;
long lenB = semicolonB - startB; final long maskedLastWordB = currentWordB & lastWordMaskB;
long lenC = semicolonC - startC; final long maskedLastWordC = currentWordC & lastWordMaskC;
int lenA = (int) (semicolonA - startA);
int lenB = (int) (semicolonB - startB);
int lenC = (int) (semicolonC - startC);
int mapIndexA = hashA & MAP_MASK;
int mapIndexB = hashB & MAP_MASK;
int mapIndexC = hashC & MAP_MASK;
long baseEntryPtrA; long baseEntryPtrA;
if (lenA > 15) {
baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, (int) hashA, maskedWordA);
}
else {
baseEntryPtrA = getOrCreateEntryBaseOffsetFast(lenA, (int) hashA, maskedWordA, maskedFirstWordA);
}
long baseEntryPtrB; long baseEntryPtrB;
if (lenB > 15) {
baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, (int) hashB, maskedWordB);
}
else {
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(lenB, (int) hashB, maskedWordB, maskedFirstWordB);
}
long baseEntryPtrC; long baseEntryPtrC;
if (lenC > 15) {
baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, (int) hashC, maskedWordC); if (slowSome) {
if (slowA) {
baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, hashA, maskedLastWordA);
}
else {
baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA);
}
if (slowB) {
baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, hashB, maskedLastWordB);
}
else {
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB);
}
if (slowC) {
baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, hashC, maskedLastWordC);
}
else {
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC);
}
} }
else { else {
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(lenC, (int) hashC, maskedWordC, maskedFirstWordC); baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA);
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB);
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC);
} }
cursorA = parseAndStoreTemperature(digitStartA, baseEntryPtrA, temperatureWordA); cursorA = parseAndStoreTemperature(digitStartA, baseEntryPtrA, temperatureWordA);
@ -502,36 +559,35 @@ public class CalculateAverage_jerrinot {
// System.out.println("Longest chain: " + longestChain); // System.out.println("Longest chain: " + longestChain);
} }
private long getOrCreateEntryBaseOffsetFast(long lenLong, int hash, long maskedLastWord, long maskedFirstWord) { private long getOrCreateEntryBaseOffsetFast(int mapIndexA, int lenA, long maskedLastWord, long maskedFirstWord) {
int lenA = (int) lenLong;
long mapIndexA = hash & MAP_MASK;
for (;;) { for (;;) {
long basePtr = mapIndexA * FAST_MAP_ENTRY_SIZE_BYTES + fastMap; long basePtr = mapIndexA * FAST_MAP_ENTRY_SIZE_BYTES + fastMap;
long namePart1 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART1);
long namePart2 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART2);
if (namePart1 == maskedFirstWord && namePart2 == maskedLastWord) {
return basePtr;
}
long lenPtr = basePtr + MAP_LEN_OFFSET; long lenPtr = basePtr + MAP_LEN_OFFSET;
int len = UNSAFE.getInt(lenPtr); int len = UNSAFE.getInt(lenPtr);
if (len == lenA) { if (len == 0) {
long namePart1 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART1); return newEntryFast(lenA, maskedLastWord, maskedFirstWord, lenPtr, basePtr);
long namePart2 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART2);
if (namePart1 == maskedFirstWord && namePart2 == maskedLastWord) {
return basePtr;
}
}
else if (len == 0) {
UNSAFE.putInt(lenPtr, lenA);
// todo: this could be a single putLong()
UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE);
UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE);
UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART1, maskedFirstWord);
UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART2, maskedLastWord);
return basePtr;
} }
mapIndexA = ++mapIndexA & MAP_MASK; mapIndexA = ++mapIndexA & MAP_MASK;
} }
} }
private long getOrCreateEntryBaseOffsetSlow(long lenLong, long startPtr, int hash, long maskedLastWord) { private static long newEntryFast(int lenA, long maskedLastWord, long maskedFirstWord, long lenPtr, long basePtr) {
long fullLen = lenLong & ~7L; UNSAFE.putInt(lenPtr, lenA);
int lenA = (int) lenLong; // todo: this could be a single putLong()
UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE);
UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE);
UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART1, maskedFirstWord);
UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART2, maskedLastWord);
return basePtr;
}
private long getOrCreateEntryBaseOffsetSlow(int lenA, long startPtr, int hash, long maskedLastWord) {
long fullLen = lenA & ~7L;
long mapIndexA = hash & MAP_MASK; long mapIndexA = hash & MAP_MASK;
for (;;) { for (;;) {
long basePtr = mapIndexA * SLOW_MAP_ENTRY_SIZE_BYTES + slowMap; long basePtr = mapIndexA * SLOW_MAP_ENTRY_SIZE_BYTES + slowMap;
@ -550,7 +606,7 @@ public class CalculateAverage_jerrinot {
UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE); UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE);
UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE); UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE);
UNSAFE.copyMemory(startPtr, slowMapNamesPtr, lenA); UNSAFE.copyMemory(startPtr, slowMapNamesPtr, lenA);
long alignedLen = (lenLong & ~7L) + 8; long alignedLen = (lenA & ~7L) + 8;
slowMapNamesPtr += alignedLen; slowMapNamesPtr += alignedLen;
return basePtr; return basePtr;
} }