jerrinot's improvement (#607)
* some random changes with minimal, if any, effect * use munmap() trick credit: thomaswue * some smaller tweaks * use native image
This commit is contained in:
parent
a6cd83fc98
commit
d9ab36a241
@ -17,5 +17,11 @@
|
||||
|
||||
# -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_mtopolnik::recordMeasurementAndAdvanceCursor"
|
||||
# -XX:InlineSmallCode=10000 -XX:-TieredCompilation -XX:CICompilerCount=2 -XX:CompileThreshold=1000\
|
||||
java -XX:+UseParallelGC --enable-preview \
|
||||
--class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jerrinot
|
||||
if [ -f target/CalculateAverage_jerrinot_image ]; then
|
||||
echo "Picking up existing native image 'target/CalculateAverage_jerrinot_image', delete the file to select JVM mode." 1>&2
|
||||
target/CalculateAverage_jerrinot_image
|
||||
else
|
||||
JAVA_OPTS="--enable-preview"
|
||||
echo "Choosing to run the app in JVM mode as no native image was found, use prepare_jerrinot.sh to generate." 1>&2
|
||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jerrinot
|
||||
fi
|
||||
|
@ -16,4 +16,11 @@
|
||||
#
|
||||
|
||||
source "$HOME/.sdkman/bin/sdkman-init.sh"
|
||||
sdk use java 21.0.1-graal 1>&2
|
||||
sdk use java 21.0.2-graal 1>&2
|
||||
|
||||
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
|
||||
if [ ! -f target/CalculateAverage_jerrinot_image ]; then
|
||||
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_jerrinot"
|
||||
# Use -H:MethodFilter=CalculateAverage_jerrinot.* -H:Dump=:2 -H:PrintGraph=Network for IdealGraphVisualizer graph dumping.
|
||||
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_jerrinot_image dev.morling.onebrc.CalculateAverage_jerrinot
|
||||
fi
|
||||
|
@ -18,6 +18,7 @@ package dev.morling.onebrc;
|
||||
import sun.misc.Unsafe;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.RandomAccessFile;
|
||||
import java.lang.foreign.Arena;
|
||||
import java.lang.reflect.Field;
|
||||
@ -54,9 +55,29 @@ public class CalculateAverage_jerrinot {
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
// credits for spawning new workers: thomaswue
|
||||
if (args.length == 0 || !("--worker".equals(args[0]))) {
|
||||
spawnWorker();
|
||||
return;
|
||||
}
|
||||
calculate();
|
||||
}
|
||||
|
||||
private static void spawnWorker() throws IOException {
|
||||
ProcessHandle.Info info = ProcessHandle.current().info();
|
||||
ArrayList<String> workerCommand = new ArrayList<>();
|
||||
info.command().ifPresent(workerCommand::add);
|
||||
info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args)));
|
||||
workerCommand.add("--worker");
|
||||
new ProcessBuilder()
|
||||
.command(workerCommand)
|
||||
.inheritIO()
|
||||
.redirectOutput(ProcessBuilder.Redirect.PIPE)
|
||||
.start()
|
||||
.getInputStream()
|
||||
.transferTo(System.out);
|
||||
}
|
||||
|
||||
static void calculate() throws Exception {
|
||||
final File file = new File(MEASUREMENTS_TXT);
|
||||
final long length = file.length();
|
||||
@ -140,6 +161,7 @@ public class CalculateAverage_jerrinot {
|
||||
}
|
||||
sb.append('}');
|
||||
System.out.println(sb);
|
||||
System.out.close();
|
||||
}
|
||||
|
||||
public static int ceilPow2(int i) {
|
||||
@ -187,7 +209,7 @@ public class CalculateAverage_jerrinot {
|
||||
private static final int SLOW_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * SLOW_MAP_ENTRY_SIZE_BYTES;
|
||||
private static final int FAST_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * FAST_MAP_ENTRY_SIZE_BYTES;
|
||||
private static final int SLOW_MAP_MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES;
|
||||
private static final long MAP_MASK = MAPS_SLOT_COUNT - 1;
|
||||
private static final int MAP_MASK = MAPS_SLOT_COUNT - 1;
|
||||
|
||||
private long slowMap;
|
||||
private long slowMapNamesPtr;
|
||||
@ -281,9 +303,9 @@ public class CalculateAverage_jerrinot {
|
||||
doOne(cursorC, endC);
|
||||
|
||||
transferToHeap();
|
||||
UNSAFE.freeMemory(fastMap);
|
||||
UNSAFE.freeMemory(slowMap);
|
||||
UNSAFE.freeMemory(slowMapNamesLo);
|
||||
// UNSAFE.freeMemory(fastMap);
|
||||
// UNSAFE.freeMemory(slowMap);
|
||||
// UNSAFE.freeMemory(slowMapNamesLo);
|
||||
}
|
||||
|
||||
private void transferToHeap() {
|
||||
@ -339,11 +361,11 @@ public class CalculateAverage_jerrinot {
|
||||
long mask = getDelimiterMask(currentWord);
|
||||
long firstWordMask = ((mask - 1) ^ mask) >>> 8;
|
||||
final long isMaskZeroA = ((mask | -mask) >>> 63) ^ 1;
|
||||
long ext = -isMaskZeroA & 0xFF00_0000_0000_0000L;
|
||||
long ext = -isMaskZeroA;
|
||||
firstWordMask |= ext;
|
||||
|
||||
long maskedFirstWord = currentWord & firstWordMask;
|
||||
long hash = hash(maskedFirstWord);
|
||||
int hash = hash(maskedFirstWord);
|
||||
while (mask == 0) {
|
||||
cursor += 8;
|
||||
currentWord = UNSAFE.getLong(cursor);
|
||||
@ -353,22 +375,22 @@ public class CalculateAverage_jerrinot {
|
||||
final long semicolon = cursor + (delimiterByte >> 3);
|
||||
final long maskedWord = currentWord & ((mask - 1) ^ mask) >>> 8;
|
||||
|
||||
long len = semicolon - start;
|
||||
long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, (int) hash, maskedWord);
|
||||
int len = (int) (semicolon - start);
|
||||
long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, hash, maskedWord);
|
||||
long temperatureWord = UNSAFE.getLong(semicolon + 1);
|
||||
cursor = parseAndStoreTemperature(semicolon + 1, baseEntryPtr, temperatureWord);
|
||||
}
|
||||
}
|
||||
|
||||
private static long hash(long word1) {
|
||||
private static int hash(long word) {
|
||||
// credit: mtopolnik
|
||||
long seed = 0x51_7c_c1_b7_27_22_0a_95L;
|
||||
int rotDist = 17;
|
||||
|
||||
long hash = word1;
|
||||
//
|
||||
long hash = word;
|
||||
hash *= seed;
|
||||
hash = Long.rotateLeft(hash, rotDist);
|
||||
return hash;
|
||||
return (int) hash;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -382,70 +404,88 @@ public class CalculateAverage_jerrinot {
|
||||
UNSAFE.setMemory(slowMapNamesPtr, SLOW_MAP_MAP_NAMES_BYTES, (byte) 0);
|
||||
|
||||
while (cursorA < endA && cursorB < endB && cursorC < endC) {
|
||||
long currentWordA = UNSAFE.getLong(cursorA);
|
||||
long currentWordB = UNSAFE.getLong(cursorB);
|
||||
long currentWordC = UNSAFE.getLong(cursorC);
|
||||
|
||||
long startA = cursorA;
|
||||
long startB = cursorB;
|
||||
long startC = cursorC;
|
||||
|
||||
long currentWordA = UNSAFE.getLong(startA);
|
||||
long currentWordB = UNSAFE.getLong(startB);
|
||||
long currentWordC = UNSAFE.getLong(startC);
|
||||
|
||||
long maskA = getDelimiterMask(currentWordA);
|
||||
long maskB = getDelimiterMask(currentWordB);
|
||||
long maskC = getDelimiterMask(currentWordC);
|
||||
|
||||
long firstWordMaskA = (maskA ^ (maskA - 1)) >>> 8;
|
||||
long firstWordMaskB = (maskB ^ (maskB - 1)) >>> 8;
|
||||
long firstWordMaskC = (maskC ^ (maskC - 1)) >>> 8;
|
||||
long maskComplementA = -maskA;
|
||||
long maskComplementB = -maskB;
|
||||
long maskComplementC = -maskC;
|
||||
|
||||
final long isMaskZeroA = ((maskA | -maskA) >>> 63) ^ 1;
|
||||
final long isMaskZeroB = ((maskB | -maskB) >>> 63) ^ 1;
|
||||
final long isMaskZeroC = ((maskC | -maskC) >>> 63) ^ 1;
|
||||
long maskWithDelimiterA = (maskA ^ (maskA - 1));
|
||||
long maskWithDelimiterB = (maskB ^ (maskB - 1));
|
||||
long maskWithDelimiterC = (maskC ^ (maskC - 1));
|
||||
|
||||
long extA = -isMaskZeroA & 0xFF00_0000_0000_0000L;
|
||||
long extB = -isMaskZeroB & 0xFF00_0000_0000_0000L;
|
||||
long extC = -isMaskZeroC & 0xFF00_0000_0000_0000L;
|
||||
long isMaskZeroA = (((maskA | maskComplementA) >>> 63) ^ 1);
|
||||
long isMaskZeroB = (((maskB | maskComplementB) >>> 63) ^ 1);
|
||||
long isMaskZeroC = (((maskC | maskComplementC) >>> 63) ^ 1);
|
||||
|
||||
firstWordMaskA |= extA;
|
||||
firstWordMaskB |= extB;
|
||||
firstWordMaskC |= extC;
|
||||
cursorA += isMaskZeroA << 3;
|
||||
cursorB += isMaskZeroB << 3;
|
||||
cursorC += isMaskZeroC << 3;
|
||||
|
||||
long maskedFirstWordA = currentWordA & firstWordMaskA;
|
||||
long maskedFirstWordB = currentWordB & firstWordMaskB;
|
||||
long maskedFirstWordC = currentWordC & firstWordMaskC;
|
||||
long nextWordA = UNSAFE.getLong(cursorA);
|
||||
long nextWordB = UNSAFE.getLong(cursorB);
|
||||
long nextWordC = UNSAFE.getLong(cursorC);
|
||||
|
||||
// assertMasks(isMaskZeroA, maskA);
|
||||
long firstWordMaskA = maskWithDelimiterA >>> 8;
|
||||
long firstWordMaskB = maskWithDelimiterB >>> 8;
|
||||
long firstWordMaskC = maskWithDelimiterC >>> 8;
|
||||
|
||||
long hashA = hash(maskedFirstWordA);
|
||||
long hashB = hash(maskedFirstWordB);
|
||||
long hashC = hash(maskedFirstWordC);
|
||||
long nextMaskA = getDelimiterMask(nextWordA);
|
||||
long nextMaskB = getDelimiterMask(nextWordB);
|
||||
long nextMaskC = getDelimiterMask(nextWordC);
|
||||
|
||||
cursorA += isMaskZeroA * 8;
|
||||
cursorB += isMaskZeroB * 8;
|
||||
cursorC += isMaskZeroC * 8;
|
||||
boolean slowA = nextMaskA == 0;
|
||||
boolean slowB = nextMaskB == 0;
|
||||
boolean slowC = nextMaskC == 0;
|
||||
boolean slowSome = (slowA || slowB || slowC);
|
||||
|
||||
currentWordA = UNSAFE.getLong(cursorA);
|
||||
currentWordB = UNSAFE.getLong(cursorB);
|
||||
currentWordC = UNSAFE.getLong(cursorC);
|
||||
long extA = -isMaskZeroA;
|
||||
long extB = -isMaskZeroB;
|
||||
long extC = -isMaskZeroC;
|
||||
|
||||
maskA = getDelimiterMask(currentWordA);
|
||||
long maskedFirstWordA = (extA | firstWordMaskA) & currentWordA;
|
||||
long maskedFirstWordB = (extB | firstWordMaskB) & currentWordB;
|
||||
long maskedFirstWordC = (extC | firstWordMaskC) & currentWordC;
|
||||
|
||||
int hashA = hash(maskedFirstWordA);
|
||||
int hashB = hash(maskedFirstWordB);
|
||||
int hashC = hash(maskedFirstWordC);
|
||||
|
||||
currentWordA = nextWordA;
|
||||
currentWordB = nextWordB;
|
||||
currentWordC = nextWordC;
|
||||
|
||||
maskA = nextMaskA;
|
||||
maskB = nextMaskB;
|
||||
maskC = nextMaskC;
|
||||
if (slowSome) {
|
||||
while (maskA == 0) {
|
||||
cursorA += 8;
|
||||
currentWordA = UNSAFE.getLong(cursorA);
|
||||
maskA = getDelimiterMask(currentWordA);
|
||||
}
|
||||
maskB = getDelimiterMask(currentWordB);
|
||||
|
||||
while (maskB == 0) {
|
||||
cursorB += 8;
|
||||
currentWordB = UNSAFE.getLong(cursorB);
|
||||
maskB = getDelimiterMask(currentWordB);
|
||||
}
|
||||
maskC = getDelimiterMask(currentWordC);
|
||||
while (maskC == 0) {
|
||||
cursorC += 8;
|
||||
currentWordC = UNSAFE.getLong(cursorC);
|
||||
maskC = getDelimiterMask(currentWordC);
|
||||
}
|
||||
}
|
||||
|
||||
final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
|
||||
final int delimiterByteB = Long.numberOfTrailingZeros(maskB);
|
||||
@ -458,40 +498,57 @@ public class CalculateAverage_jerrinot {
|
||||
long digitStartA = semicolonA + 1;
|
||||
long digitStartB = semicolonB + 1;
|
||||
long digitStartC = semicolonC + 1;
|
||||
|
||||
long temperatureWordA = UNSAFE.getLong(digitStartA);
|
||||
long temperatureWordB = UNSAFE.getLong(digitStartB);
|
||||
long temperatureWordC = UNSAFE.getLong(digitStartC);
|
||||
|
||||
final long maskedWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8;
|
||||
final long maskedWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8;
|
||||
final long maskedWordC = currentWordC & ((maskC - 1) ^ maskC) >>> 8;
|
||||
long lastWordMaskA = ((maskA - 1) ^ maskA) >>> 8;
|
||||
long lastWordMaskB = ((maskB - 1) ^ maskB) >>> 8;
|
||||
long lastWordMaskC = ((maskC - 1) ^ maskC) >>> 8;
|
||||
|
||||
long lenA = semicolonA - startA;
|
||||
long lenB = semicolonB - startB;
|
||||
long lenC = semicolonC - startC;
|
||||
final long maskedLastWordA = currentWordA & lastWordMaskA;
|
||||
final long maskedLastWordB = currentWordB & lastWordMaskB;
|
||||
final long maskedLastWordC = currentWordC & lastWordMaskC;
|
||||
|
||||
int lenA = (int) (semicolonA - startA);
|
||||
int lenB = (int) (semicolonB - startB);
|
||||
int lenC = (int) (semicolonC - startC);
|
||||
|
||||
int mapIndexA = hashA & MAP_MASK;
|
||||
int mapIndexB = hashB & MAP_MASK;
|
||||
int mapIndexC = hashC & MAP_MASK;
|
||||
|
||||
long baseEntryPtrA;
|
||||
if (lenA > 15) {
|
||||
baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, (int) hashA, maskedWordA);
|
||||
}
|
||||
else {
|
||||
baseEntryPtrA = getOrCreateEntryBaseOffsetFast(lenA, (int) hashA, maskedWordA, maskedFirstWordA);
|
||||
}
|
||||
|
||||
long baseEntryPtrB;
|
||||
if (lenB > 15) {
|
||||
baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, (int) hashB, maskedWordB);
|
||||
long baseEntryPtrC;
|
||||
|
||||
if (slowSome) {
|
||||
if (slowA) {
|
||||
baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, hashA, maskedLastWordA);
|
||||
}
|
||||
else {
|
||||
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(lenB, (int) hashB, maskedWordB, maskedFirstWordB);
|
||||
baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA);
|
||||
}
|
||||
|
||||
long baseEntryPtrC;
|
||||
if (lenC > 15) {
|
||||
baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, (int) hashC, maskedWordC);
|
||||
if (slowB) {
|
||||
baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, hashB, maskedLastWordB);
|
||||
}
|
||||
else {
|
||||
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(lenC, (int) hashC, maskedWordC, maskedFirstWordC);
|
||||
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB);
|
||||
}
|
||||
|
||||
if (slowC) {
|
||||
baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, hashC, maskedLastWordC);
|
||||
}
|
||||
else {
|
||||
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC);
|
||||
}
|
||||
}
|
||||
else {
|
||||
baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA);
|
||||
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB);
|
||||
baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC);
|
||||
}
|
||||
|
||||
cursorA = parseAndStoreTemperature(digitStartA, baseEntryPtrA, temperatureWordA);
|
||||
@ -502,21 +559,24 @@ public class CalculateAverage_jerrinot {
|
||||
// System.out.println("Longest chain: " + longestChain);
|
||||
}
|
||||
|
||||
private long getOrCreateEntryBaseOffsetFast(long lenLong, int hash, long maskedLastWord, long maskedFirstWord) {
|
||||
int lenA = (int) lenLong;
|
||||
long mapIndexA = hash & MAP_MASK;
|
||||
private long getOrCreateEntryBaseOffsetFast(int mapIndexA, int lenA, long maskedLastWord, long maskedFirstWord) {
|
||||
for (;;) {
|
||||
long basePtr = mapIndexA * FAST_MAP_ENTRY_SIZE_BYTES + fastMap;
|
||||
long lenPtr = basePtr + MAP_LEN_OFFSET;
|
||||
int len = UNSAFE.getInt(lenPtr);
|
||||
if (len == lenA) {
|
||||
long namePart1 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART1);
|
||||
long namePart2 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART2);
|
||||
if (namePart1 == maskedFirstWord && namePart2 == maskedLastWord) {
|
||||
return basePtr;
|
||||
}
|
||||
long lenPtr = basePtr + MAP_LEN_OFFSET;
|
||||
int len = UNSAFE.getInt(lenPtr);
|
||||
if (len == 0) {
|
||||
return newEntryFast(lenA, maskedLastWord, maskedFirstWord, lenPtr, basePtr);
|
||||
}
|
||||
else if (len == 0) {
|
||||
mapIndexA = ++mapIndexA & MAP_MASK;
|
||||
}
|
||||
}
|
||||
|
||||
private static long newEntryFast(int lenA, long maskedLastWord, long maskedFirstWord, long lenPtr, long basePtr) {
|
||||
UNSAFE.putInt(lenPtr, lenA);
|
||||
// todo: this could be a single putLong()
|
||||
UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE);
|
||||
@ -525,13 +585,9 @@ public class CalculateAverage_jerrinot {
|
||||
UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART2, maskedLastWord);
|
||||
return basePtr;
|
||||
}
|
||||
mapIndexA = ++mapIndexA & MAP_MASK;
|
||||
}
|
||||
}
|
||||
|
||||
private long getOrCreateEntryBaseOffsetSlow(long lenLong, long startPtr, int hash, long maskedLastWord) {
|
||||
long fullLen = lenLong & ~7L;
|
||||
int lenA = (int) lenLong;
|
||||
private long getOrCreateEntryBaseOffsetSlow(int lenA, long startPtr, int hash, long maskedLastWord) {
|
||||
long fullLen = lenA & ~7L;
|
||||
long mapIndexA = hash & MAP_MASK;
|
||||
for (;;) {
|
||||
long basePtr = mapIndexA * SLOW_MAP_ENTRY_SIZE_BYTES + slowMap;
|
||||
@ -550,7 +606,7 @@ public class CalculateAverage_jerrinot {
|
||||
UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE);
|
||||
UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE);
|
||||
UNSAFE.copyMemory(startPtr, slowMapNamesPtr, lenA);
|
||||
long alignedLen = (lenLong & ~7L) + 8;
|
||||
long alignedLen = (lenA & ~7L) + 8;
|
||||
slowMapNamesPtr += alignedLen;
|
||||
return basePtr;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user