Native image + a few smaller optimisations (#564)

* Inline parsing of the station name and temperature to avoid constantly updating the offset field (-100ms)

* Remove Worker class, inline the logic into lambda

* Accumulate results in an int matrix instead of using result row (-50ms)

* Use native image
This commit is contained in:
Roman Musin 2024-01-23 19:19:07 +00:00 committed by GitHub
parent ba793e88cd
commit 6c0949969a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 135 additions and 134 deletions

View File

@ -21,4 +21,15 @@ JAVA_OPTS="--enable-preview -XX:+UseTransparentHugePages"
# see https://stackoverflow.com/questions/58087596/why-are-repeated-memory-allocations-observed-to-be-slower-using-epsilon-vs-g1 # see https://stackoverflow.com/questions/58087596/why-are-repeated-memory-allocations-observed-to-be-slower-using-epsilon-vs-g1
JAVA_OPTS="$JAVA_OPTS -XX:+UnlockExperimentalVMOptions -XX:-EnableJVMCI -XX:+UseEpsilonGC -Xmx1G -Xms1G -XX:+AlwaysPreTouch" JAVA_OPTS="$JAVA_OPTS -XX:+UnlockExperimentalVMOptions -XX:-EnableJVMCI -XX:+UseEpsilonGC -Xmx1G -Xms1G -XX:+AlwaysPreTouch"
if [ -f target/CalculateAverage_roman_r_m_image ]; then
echo "Picking up existing native image 'target/CalculateAverage_roman_r_m_image', delete the file to select JVM mode." 1>&2
target/CalculateAverage_roman_r_m_image
else
JAVA_OPTS="--enable-preview -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA"
if [[ ! "$(uname -s)" = "Darwin" ]]; then
# On macOS (Darwin), on my machine, the flag below causes an error, so it is only added on other platforms:
JAVA_OPTS="$JAVA_OPTS -XX:+UseTransparentHugePages"
fi
echo "Choosing to run the app in JVM mode as no native image was found, use additional_build_step_roman_r_m.sh to generate." 1>&2
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_roman_r_m java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_roman_r_m
fi

View File

@ -17,3 +17,12 @@
source "$HOME/.sdkman/bin/sdkman-init.sh" source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.1-graal 1>&2 sdk use java 21.0.1-graal 1>&2
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
# Build the native image only when the output file does not already exist.
if [ ! -f target/CalculateAverage_roman_r_m_image ]; then
# Flags appended to the end of the native-image option string below.
JAVA_OPTS="--enable-preview -dsa"
# Options handed to native-image: build-time initialisation of the main class,
# --gc=epsilon, -Ob, -O3, -march=native, --strict-image-heap, plus JAVA_OPTS from above.
NATIVE_IMAGE_OPTS="--initialize-at-build-time=dev.morling.onebrc.CalculateAverage_roman_r_m --gc=epsilon -Ob -O3 -march=native --strict-image-heap $JAVA_OPTS"
# NATIVE_IMAGE_OPTS is expanded unquoted, so the shell splits it into separate arguments.
# The image is built from the project jar and written to the path tested by the guard above.
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_roman_r_m_image dev.morling.onebrc.CalculateAverage_roman_r_m
fi

View File

@ -64,23 +64,40 @@ public class CalculateAverage_roman_r_m {
return start + Long.numberOfTrailingZeros(i) / 8; return start + Long.numberOfTrailingZeros(i) / 8;
} }
static class Worker { public static void main(String[] args) throws Exception {
private final MemorySegment ms; Field f = Unsafe.class.getDeclaredField("theUnsafe");
private final long end; f.setAccessible(true);
private long offset; UNSAFE = (Unsafe) f.get(null);
public Worker(FileChannel channel, long start, long end) { long fileSize = new File(FILE).length();
var channel = FileChannel.open(Paths.get(FILE));
MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofConfined());
int numThreads = fileSize > Integer.MAX_VALUE ? Runtime.getRuntime().availableProcessors() : 1;
long chunk = fileSize / numThreads;
var bounds = IntStream.range(0, numThreads).mapToLong(i -> {
boolean lastChunk = i == numThreads - 1;
return lastChunk ? fileSize : nextNewline((i + 1) * chunk, ms);
}).toArray();
ms.unload();
var result = IntStream.range(0, numThreads)
.parallel()
.mapToObj(i -> {
try { try {
this.ms = channel.map(FileChannel.MapMode.READ_ONLY, start, end - start, Arena.ofConfined()); long segmentStart = i == 0 ? 0 : bounds[i - 1] + 1;
this.offset = ms.address(); long segmentEnd = bounds[i];
this.end = ms.address() + end - start; var segment = channel.map(FileChannel.MapMode.READ_ONLY, segmentStart, segmentEnd - segmentStart, Arena.ofConfined());
}
catch (Exception e) {
throw new RuntimeException(e);
}
}
private void parseName(ByteString station) { var resultStore = new ResultStore();
var station = new ByteString(segment);
long offset = segment.address();
long end = offset + segment.byteSize();
while (offset < end) {
// parsing station name
long start = offset; long start = offset;
long next = UNSAFE.getLong(offset); long next = UNSAFE.getLong(offset);
long pattern = applyPattern(next, SEMICOLON_MASK); long pattern = applyPattern(next, SEMICOLON_MASK);
@ -109,37 +126,38 @@ public class CalculateAverage_roman_r_m {
station.tail = next & ((1L << (8 * bytes)) - 1); station.tail = next & ((1L << (8 * bytes)) - 1);
offset++; offset++;
}
int parseNumberFast() { // parsing temperature
// TODO next may contain temperature as well, maybe try using it if we know the full number is there
// 8 - bytes >= 5 -> bytes <= 3
long val;
if (end - offset >= 8) {
long encodedVal = UNSAFE.getLong(offset); long encodedVal = UNSAFE.getLong(offset);
int neg = 1 - Integer.bitCount((int) (encodedVal & 0x10)); int neg = 1 - Integer.bitCount((int) (encodedVal & 0x10));
encodedVal >>>= 8 * neg; encodedVal >>>= 8 * neg;
var len = applyPattern(encodedVal, DOT_MASK); long numLen = applyPattern(encodedVal, DOT_MASK);
len = Long.numberOfTrailingZeros(len) / 8; numLen = Long.numberOfTrailingZeros(numLen) / 8;
encodedVal ^= broadcast((byte) 0x30); encodedVal ^= broadcast((byte) 0x30);
int intPart = (int) (encodedVal & ((1 << (8 * len)) - 1)); int intPart = (int) (encodedVal & ((1 << (8 * numLen)) - 1));
intPart <<= 8 * (2 - len); intPart <<= 8 * (2 - numLen);
intPart *= (100 * 256 + 10); intPart *= (100 * 256 + 10);
intPart = (intPart & 0x3FF80) >>> 8; intPart = (intPart & 0x3FF80) >>> 8;
int frac = (int) ((encodedVal >>> (8 * (len + 1))) & 0xFF); int frac = (int) ((encodedVal >>> (8 * (numLen + 1))) & 0xFF);
offset += neg + len + 3; // 1 for . + 1 for fractional part + 1 for new line char offset += neg + numLen + 3; // 1 for . + 1 for fractional part + 1 for new line char
int sign = 1 - 2 * neg; int sign = 1 - 2 * neg;
int val = intPart + frac; val = sign * (intPart + frac);
return sign * val;
} }
else {
int parseNumberSlow() {
int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10); int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10);
offset += neg; offset += neg;
int val = UNSAFE.getByte(offset++) - '0'; val = UNSAFE.getByte(offset++) - '0';
byte b; byte b;
while ((b = UNSAFE.getByte(offset++)) != '.') { while ((b = UNSAFE.getByte(offset++)) != '.') {
val = val * 10 + (b - '0'); val = val * 10 + (b - '0');
@ -147,65 +165,19 @@ public class CalculateAverage_roman_r_m {
b = UNSAFE.getByte(offset); b = UNSAFE.getByte(offset);
val = val * 10 + (b - '0'); val = val * 10 + (b - '0');
offset += 2; offset += 2;
val *= 1 - 2 * neg; val *= 1 - (2L * neg);
return val;
} }
int parseNumber() { resultStore.update(station, (int) val);
if (end - offset >= 8) {
return parseNumberFast();
}
else {
return parseNumberSlow();
}
} }
public TreeMap<String, ResultRow> run() { segment.unload();
var resultStore = new ResultStore();
var station = new ByteString(ms);
while (offset < end) {
parseName(station);
long val = parseNumber();
var a = resultStore.get(station);
a.min = Math.min(a.min, val);
a.max = Math.max(a.max, val);
a.sum += val;
a.count++;
}
return resultStore.toMap(); return resultStore.toMap();
} }
catch (Exception e) {
throw new RuntimeException(e);
} }
public static void main(String[] args) throws Exception {
Field f = Unsafe.class.getDeclaredField("theUnsafe");
f.setAccessible(true);
UNSAFE = (Unsafe) f.get(null);
long fileSize = new File(FILE).length();
var channel = FileChannel.open(Paths.get(FILE));
MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofConfined());
int numThreads = fileSize > Integer.MAX_VALUE ? Runtime.getRuntime().availableProcessors() : 1;
long chunk = fileSize / numThreads;
var bounds = IntStream.range(0, numThreads).mapToLong(i -> {
boolean lastChunk = i == numThreads - 1;
return lastChunk ? fileSize : nextNewline((i + 1) * chunk, ms);
}).toArray();
ms.unload();
var result = IntStream.range(0, numThreads)
.parallel()
.mapToObj(i -> {
long start = i == 0 ? 0 : bounds[i - 1] + 1;
long end = bounds[i];
Worker worker = new Worker(channel, start, end);
var res = worker.run();
worker.ms.unload();
return res;
}).reduce((m1, m2) -> { }).reduce((m1, m2) -> {
m2.forEach((k, v) -> m1.merge(k, v, ResultRow::merge)); m2.forEach((k, v) -> m1.merge(k, v, ResultRow::merge));
return m1; return m1;
@ -275,10 +247,17 @@ public class CalculateAverage_roman_r_m {
} }
private static final class ResultRow { private static final class ResultRow {
long min = 1000; long min;
long sum = 0; long sum;
long max = -1000; long max;
int count = 0; int count;
public ResultRow(int[] values) {
min = values[0];
max = values[1];
sum = values[2];
count = values[3];
}
public String toString() { public String toString() {
return round(min / 10.0) + "/" + round(sum / 10.0 / count) + "/" + round(max / 10.0); return round(min / 10.0) + "/" + round(sum / 10.0 / count) + "/" + round(max / 10.0);
@ -300,9 +279,9 @@ public class CalculateAverage_roman_r_m {
static class ResultStore { static class ResultStore {
private static final int SIZE = 16384; private static final int SIZE = 16384;
private final ByteString[] keys = new ByteString[SIZE]; private final ByteString[] keys = new ByteString[SIZE];
private final ResultRow[] values = new ResultRow[SIZE]; private final int[][] values = new int[SIZE][];
ResultRow get(ByteString s) { void update(ByteString s, int value) {
int h = s.hashCode(); int h = s.hashCode();
int idx = (SIZE - 1) & h; int idx = (SIZE - 1) & h;
@ -311,18 +290,20 @@ public class CalculateAverage_roman_r_m {
i++; i++;
idx = (idx + i * i) % SIZE; idx = (idx + i * i) % SIZE;
} }
ResultRow result;
if (keys[idx] == null) { if (keys[idx] == null) {
keys[idx] = s.copy(); keys[idx] = s.copy();
result = new ResultRow(); values[idx] = new int[4];
values[idx] = result; values[idx][0] = value;
values[idx][1] = value;
values[idx][2] = value;
values[idx][3] = 1;
} }
else { else {
result = values[idx]; values[idx][0] = Math.min(values[idx][0], value);
// TODO see it it makes any difference values[idx][1] = Math.max(values[idx][1], value);
// keys[idx].offset = s.offset; values[idx][2] += value;
values[idx][3] += 1;
} }
return result;
} }
TreeMap<String, ResultRow> toMap() { TreeMap<String, ResultRow> toMap() {
@ -330,7 +311,7 @@ public class CalculateAverage_roman_r_m {
var result = new TreeMap<String, ResultRow>(); var result = new TreeMap<String, ResultRow>();
for (int i = 0; i < SIZE; i++) { for (int i = 0; i < SIZE; i++) {
if (keys[i] != null) { if (keys[i] != null) {
result.put(keys[i].asString(buf), values[i]); result.put(keys[i].asString(buf), new ResultRow(values[i]));
} }
} }
return result; return result;