jerrinot's improvement (#607)

* some random changes with minimal, if any, effect

* use munmap() trick
credit: thomaswue

* some smaller tweaks

* use native image
This commit is contained in:
Jaromir Hamala 2024-01-28 11:34:28 +01:00 committed by GitHub
parent a6cd83fc98
commit d9ab36a241
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 175 additions and 106 deletions

View File

@ -17,5 +17,11 @@
# -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_mtopolnik::recordMeasurementAndAdvanceCursor"
# -XX:InlineSmallCode=10000 -XX:-TieredCompilation -XX:CICompilerCount=2 -XX:CompileThreshold=1000\
java -XX:+UseParallelGC --enable-preview \
--class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jerrinot
# Run the prebuilt native image when it exists; otherwise fall back to running the class on the JVM.
if [ -f target/CalculateAverage_jerrinot_image ]; then
# Status messages are sent to stderr (1>&2), keeping them out of the program's stdout.
echo "Picking up existing native image 'target/CalculateAverage_jerrinot_image', delete the file to select JVM mode." 1>&2
target/CalculateAverage_jerrinot_image
else
JAVA_OPTS="--enable-preview"
echo "Choosing to run the app in JVM mode as no native image was found, use prepare_jerrinot.sh to generate." 1>&2
# $JAVA_OPTS is left unquoted, so the shell word-splits it into separate launcher arguments.
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jerrinot
fi

View File

@ -16,4 +16,11 @@
#
source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.1-graal 1>&2
sdk use java 21.0.2-graal 1>&2
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
# Build the native image only when it is not already present in target/.
if [ ! -f target/CalculateAverage_jerrinot_image ]; then
# Options are collected in one variable and expanded unquoted below, so each flag becomes its own argument.
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_jerrinot"
# Use -H:MethodFilter=CalculateAverage_jerrinot.* -H:Dump=:2 -H:PrintGraph=Network for IdealGraphVisualizer graph dumping.
# The output path (-o) is the same file the existence check above looks for.
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_jerrinot_image dev.morling.onebrc.CalculateAverage_jerrinot
fi

View File

@ -18,6 +18,7 @@ package dev.morling.onebrc;
import sun.misc.Unsafe;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
@ -54,9 +55,29 @@ public class CalculateAverage_jerrinot {
}
/**
 * Entry point. When the first argument is {@code --worker} the calculation runs in this
 * process; in every other case (including no arguments at all) a worker process is spawned
 * and this process only relays its output.
 *
 * @param args command-line arguments; only {@code args[0]} is inspected
 * @throws Exception propagated from {@code calculate()} or {@code spawnWorker()}
 */
public static void main(String[] args) throws Exception {
    // credits for spawning new workers: thomaswue
    final boolean runsAsWorker = args.length > 0 && "--worker".equals(args[0]);
    if (runsAsWorker) {
        calculate();
    }
    else {
        spawnWorker();
    }
}
/**
 * Re-launches the current process with an extra trailing {@code --worker} argument and
 * copies the child's standard output to this process's {@code System.out}.
 *
 * <p>The child inherits stdin and stderr; only stdout is piped. This method returns as soon
 * as the child's stdout reaches end-of-stream and deliberately does NOT wait for the child to
 * exit. NOTE(review): presumably this is the "munmap() trick" from the commit message, letting
 * the parent finish while the worker is still tearing down; confirm before adding a
 * {@code waitFor()}.
 *
 * @throws IOException if the path of the current executable cannot be determined, if the
 *                     worker process cannot be started, or if relaying its output fails
 */
private static void spawnWorker() throws IOException {
    ProcessHandle.Info info = ProcessHandle.current().info();

    // Fail with a clear message instead of silently building a command line without an
    // executable, which would try to run the first JVM argument (or "--worker") as a program.
    String executable = info.command()
            .orElseThrow(() -> new IOException("Cannot spawn worker: the path of the current executable is not available"));

    ArrayList<String> workerCommand = new ArrayList<>();
    workerCommand.add(executable);
    // Arguments may be unavailable on some platforms; in that case only the executable is reused.
    info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args)));
    workerCommand.add("--worker");

    Process worker = new ProcessBuilder()
            .command(workerCommand)
            .inheritIO()
            // Overrides the inherited stdout only, so it can be read through getInputStream().
            .redirectOutput(ProcessBuilder.Redirect.PIPE)
            .start();

    // Blocks until the worker closes its stdout, then returns without waiting for process exit.
    worker.getInputStream().transferTo(System.out);
}
static void calculate() throws Exception {
final File file = new File(MEASUREMENTS_TXT);
final long length = file.length();
@ -140,6 +161,7 @@ public class CalculateAverage_jerrinot {
}
sb.append('}');
System.out.println(sb);
System.out.close();
}
public static int ceilPow2(int i) {
@ -187,7 +209,7 @@ public class CalculateAverage_jerrinot {
private static final int SLOW_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * SLOW_MAP_ENTRY_SIZE_BYTES;
private static final int FAST_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * FAST_MAP_ENTRY_SIZE_BYTES;
private static final int SLOW_MAP_MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES;
private static final long MAP_MASK = MAPS_SLOT_COUNT - 1;
private static final int MAP_MASK = MAPS_SLOT_COUNT - 1;
private long slowMap;
private long slowMapNamesPtr;
@ -281,9 +303,9 @@ public class CalculateAverage_jerrinot {
doOne(cursorC, endC);
transferToHeap();
UNSAFE.freeMemory(fastMap);
UNSAFE.freeMemory(slowMap);
UNSAFE.freeMemory(slowMapNamesLo);
// UNSAFE.freeMemory(fastMap);
// UNSAFE.freeMemory(slowMap);
// UNSAFE.freeMemory(slowMapNamesLo);
}
private void transferToHeap() {
@ -339,11 +361,11 @@ public class CalculateAverage_jerrinot {
long mask = getDelimiterMask(currentWord);
long firstWordMask = ((mask - 1) ^ mask) >>> 8;
final long isMaskZeroA = ((mask | -mask) >>> 63) ^ 1;
long ext = -isMaskZeroA & 0xFF00_0000_0000_0000L;
long ext = -isMaskZeroA;
firstWordMask |= ext;
long maskedFirstWord = currentWord & firstWordMask;
long hash = hash(maskedFirstWord);
int hash = hash(maskedFirstWord);
while (mask == 0) {
cursor += 8;
currentWord = UNSAFE.getLong(cursor);
@ -353,22 +375,22 @@ public class CalculateAverage_jerrinot {
final long semicolon = cursor + (delimiterByte >> 3);
final long maskedWord = currentWord & ((mask - 1) ^ mask) >>> 8;
long len = semicolon - start;
long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, (int) hash, maskedWord);
int len = (int) (semicolon - start);
long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, hash, maskedWord);
long temperatureWord = UNSAFE.getLong(semicolon + 1);
cursor = parseAndStoreTemperature(semicolon + 1, baseEntryPtr, temperatureWord);
}
}
private static long hash(long word1) {
private static int hash(long word) {
// credit: mtopolnik
long seed = 0x51_7c_c1_b7_27_22_0a_95L;
int rotDist = 17;
long hash = word1;
//
long hash = word;
hash *= seed;
hash = Long.rotateLeft(hash, rotDist);
return hash;
return (int) hash;
}
@Override
@ -382,70 +404,88 @@ public class CalculateAverage_jerrinot {
UNSAFE.setMemory(slowMapNamesPtr, SLOW_MAP_MAP_NAMES_BYTES, (byte) 0);
while (cursorA < endA && cursorB < endB && cursorC < endC) {
long currentWordA = UNSAFE.getLong(cursorA);
long currentWordB = UNSAFE.getLong(cursorB);
long currentWordC = UNSAFE.getLong(cursorC);
long startA = cursorA;
long startB = cursorB;
long startC = cursorC;
long currentWordA = UNSAFE.getLong(startA);
long currentWordB = UNSAFE.getLong(startB);
long currentWordC = UNSAFE.getLong(startC);
long maskA = getDelimiterMask(currentWordA);
long maskB = getDelimiterMask(currentWordB);
long maskC = getDelimiterMask(currentWordC);
long firstWordMaskA = (maskA ^ (maskA - 1)) >>> 8;
long firstWordMaskB = (maskB ^ (maskB - 1)) >>> 8;
long firstWordMaskC = (maskC ^ (maskC - 1)) >>> 8;
long maskComplementA = -maskA;
long maskComplementB = -maskB;
long maskComplementC = -maskC;
final long isMaskZeroA = ((maskA | -maskA) >>> 63) ^ 1;
final long isMaskZeroB = ((maskB | -maskB) >>> 63) ^ 1;
final long isMaskZeroC = ((maskC | -maskC) >>> 63) ^ 1;
long maskWithDelimiterA = (maskA ^ (maskA - 1));
long maskWithDelimiterB = (maskB ^ (maskB - 1));
long maskWithDelimiterC = (maskC ^ (maskC - 1));
long extA = -isMaskZeroA & 0xFF00_0000_0000_0000L;
long extB = -isMaskZeroB & 0xFF00_0000_0000_0000L;
long extC = -isMaskZeroC & 0xFF00_0000_0000_0000L;
long isMaskZeroA = (((maskA | maskComplementA) >>> 63) ^ 1);
long isMaskZeroB = (((maskB | maskComplementB) >>> 63) ^ 1);
long isMaskZeroC = (((maskC | maskComplementC) >>> 63) ^ 1);
firstWordMaskA |= extA;
firstWordMaskB |= extB;
firstWordMaskC |= extC;
cursorA += isMaskZeroA << 3;
cursorB += isMaskZeroB << 3;
cursorC += isMaskZeroC << 3;
long maskedFirstWordA = currentWordA & firstWordMaskA;
long maskedFirstWordB = currentWordB & firstWordMaskB;
long maskedFirstWordC = currentWordC & firstWordMaskC;
long nextWordA = UNSAFE.getLong(cursorA);
long nextWordB = UNSAFE.getLong(cursorB);
long nextWordC = UNSAFE.getLong(cursorC);
// assertMasks(isMaskZeroA, maskA);
long firstWordMaskA = maskWithDelimiterA >>> 8;
long firstWordMaskB = maskWithDelimiterB >>> 8;
long firstWordMaskC = maskWithDelimiterC >>> 8;
long hashA = hash(maskedFirstWordA);
long hashB = hash(maskedFirstWordB);
long hashC = hash(maskedFirstWordC);
long nextMaskA = getDelimiterMask(nextWordA);
long nextMaskB = getDelimiterMask(nextWordB);
long nextMaskC = getDelimiterMask(nextWordC);
cursorA += isMaskZeroA * 8;
cursorB += isMaskZeroB * 8;
cursorC += isMaskZeroC * 8;
boolean slowA = nextMaskA == 0;
boolean slowB = nextMaskB == 0;
boolean slowC = nextMaskC == 0;
boolean slowSome = (slowA || slowB || slowC);
currentWordA = UNSAFE.getLong(cursorA);
currentWordB = UNSAFE.getLong(cursorB);
currentWordC = UNSAFE.getLong(cursorC);
long extA = -isMaskZeroA;
long extB = -isMaskZeroB;
long extC = -isMaskZeroC;
maskA = getDelimiterMask(currentWordA);
long maskedFirstWordA = (extA | firstWordMaskA) & currentWordA;
long maskedFirstWordB = (extB | firstWordMaskB) & currentWordB;
long maskedFirstWordC = (extC | firstWordMaskC) & currentWordC;
int hashA = hash(maskedFirstWordA);
int hashB = hash(maskedFirstWordB);
int hashC = hash(maskedFirstWordC);
currentWordA = nextWordA;
currentWordB = nextWordB;
currentWordC = nextWordC;
maskA = nextMaskA;
maskB = nextMaskB;
maskC = nextMaskC;
if (slowSome) {
while (maskA == 0) {
cursorA += 8;
currentWordA = UNSAFE.getLong(cursorA);
maskA = getDelimiterMask(currentWordA);
}
maskB = getDelimiterMask(currentWordB);
while (maskB == 0) {
cursorB += 8;
currentWordB = UNSAFE.getLong(cursorB);
maskB = getDelimiterMask(currentWordB);
}
maskC = getDelimiterMask(currentWordC);
while (maskC == 0) {
cursorC += 8;
currentWordC = UNSAFE.getLong(cursorC);
maskC = getDelimiterMask(currentWordC);
}
}
final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
final int delimiterByteB = Long.numberOfTrailingZeros(maskB);
@ -458,40 +498,57 @@ public class CalculateAverage_jerrinot {
long digitStartA = semicolonA + 1;
long digitStartB = semicolonB + 1;
long digitStartC = semicolonC + 1;
long temperatureWordA = UNSAFE.getLong(digitStartA);
long temperatureWordB = UNSAFE.getLong(digitStartB);
long temperatureWordC = UNSAFE.getLong(digitStartC);
final long maskedWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8;
final long maskedWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8;
final long maskedWordC = currentWordC & ((maskC - 1) ^ maskC) >>> 8;
long lastWordMaskA = ((maskA - 1) ^ maskA) >>> 8;
long lastWordMaskB = ((maskB - 1) ^ maskB) >>> 8;
long lastWordMaskC = ((maskC - 1) ^ maskC) >>> 8;
long lenA = semicolonA - startA;
long lenB = semicolonB - startB;
long lenC = semicolonC - startC;
final long maskedLastWordA = currentWordA & lastWordMaskA;
final long maskedLastWordB = currentWordB & lastWordMaskB;
final long maskedLastWordC = currentWordC & lastWordMaskC;
int lenA = (int) (semicolonA - startA);
int lenB = (int) (semicolonB - startB);
int lenC = (int) (semicolonC - startC);
int mapIndexA = hashA & MAP_MASK;
int mapIndexB = hashB & MAP_MASK;
int mapIndexC = hashC & MAP_MASK;
long baseEntryPtrA;
if (lenA > 15) {
baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, (int) hashA, maskedWordA);
}
else {
baseEntryPtrA = getOrCreateEntryBaseOffsetFast(lenA, (int) hashA, maskedWordA, maskedFirstWordA);
}
long baseEntryPtrB;
if (lenB > 15) {
baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, (int) hashB, maskedWordB);
long baseEntryPtrC;
if (slowSome) {
if (slowA) {
baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, hashA, maskedLastWordA);
}
else {
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(lenB, (int) hashB, maskedWordB, maskedFirstWordB);
baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA);
}
long baseEntryPtrC;
if (lenC > 15) {
baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, (int) hashC, maskedWordC);
if (slowB) {
baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, hashB, maskedLastWordB);
}
else {
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(lenC, (int) hashC, maskedWordC, maskedFirstWordC);
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB);
}
if (slowC) {
baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, hashC, maskedLastWordC);
}
else {
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC);
}
}
else {
baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA);
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB);
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC);
}
cursorA = parseAndStoreTemperature(digitStartA, baseEntryPtrA, temperatureWordA);
@ -502,21 +559,24 @@ public class CalculateAverage_jerrinot {
// System.out.println("Longest chain: " + longestChain);
}
private long getOrCreateEntryBaseOffsetFast(long lenLong, int hash, long maskedLastWord, long maskedFirstWord) {
int lenA = (int) lenLong;
long mapIndexA = hash & MAP_MASK;
private long getOrCreateEntryBaseOffsetFast(int mapIndexA, int lenA, long maskedLastWord, long maskedFirstWord) {
for (;;) {
long basePtr = mapIndexA * FAST_MAP_ENTRY_SIZE_BYTES + fastMap;
long lenPtr = basePtr + MAP_LEN_OFFSET;
int len = UNSAFE.getInt(lenPtr);
if (len == lenA) {
long namePart1 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART1);
long namePart2 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART2);
if (namePart1 == maskedFirstWord && namePart2 == maskedLastWord) {
return basePtr;
}
long lenPtr = basePtr + MAP_LEN_OFFSET;
int len = UNSAFE.getInt(lenPtr);
if (len == 0) {
return newEntryFast(lenA, maskedLastWord, maskedFirstWord, lenPtr, basePtr);
}
else if (len == 0) {
mapIndexA = ++mapIndexA & MAP_MASK;
}
}
private static long newEntryFast(int lenA, long maskedLastWord, long maskedFirstWord, long lenPtr, long basePtr) {
UNSAFE.putInt(lenPtr, lenA);
// todo: this could be a single putLong()
UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE);
@ -525,13 +585,9 @@ public class CalculateAverage_jerrinot {
UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART2, maskedLastWord);
return basePtr;
}
mapIndexA = ++mapIndexA & MAP_MASK;
}
}
private long getOrCreateEntryBaseOffsetSlow(long lenLong, long startPtr, int hash, long maskedLastWord) {
long fullLen = lenLong & ~7L;
int lenA = (int) lenLong;
private long getOrCreateEntryBaseOffsetSlow(int lenA, long startPtr, int hash, long maskedLastWord) {
long fullLen = lenA & ~7L;
long mapIndexA = hash & MAP_MASK;
for (;;) {
long basePtr = mapIndexA * SLOW_MAP_ENTRY_SIZE_BYTES + slowMap;
@ -550,7 +606,7 @@ public class CalculateAverage_jerrinot {
UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE);
UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE);
UNSAFE.copyMemory(startPtr, slowMapNamesPtr, lenA);
long alignedLen = (lenLong & ~7L) + 8;
long alignedLen = (lenA & ~7L) + 8;
slowMapNamesPtr += alignedLen;
return basePtr;
}