Native image + a few smaller optimisations (#564)
* Inline parsing name and station to avoid constantly updating the offset field (-100ms) * Remove Worker class, inline the logic into lambda * Accumulate results in an int matrix instead of using result row (-50ms) * Use native image
This commit is contained in:
parent
ba793e88cd
commit
6c0949969a
@ -21,4 +21,15 @@ JAVA_OPTS="--enable-preview -XX:+UseTransparentHugePages"
|
|||||||
# see https://stackoverflow.com/questions/58087596/why-are-repeated-memory-allocations-observed-to-be-slower-using-epsilon-vs-g1
|
# see https://stackoverflow.com/questions/58087596/why-are-repeated-memory-allocations-observed-to-be-slower-using-epsilon-vs-g1
|
||||||
JAVA_OPTS="$JAVA_OPTS -XX:+UnlockExperimentalVMOptions -XX:-EnableJVMCI -XX:+UseEpsilonGC -Xmx1G -Xms1G -XX:+AlwaysPreTouch"
|
JAVA_OPTS="$JAVA_OPTS -XX:+UnlockExperimentalVMOptions -XX:-EnableJVMCI -XX:+UseEpsilonGC -Xmx1G -Xms1G -XX:+AlwaysPreTouch"
|
||||||
|
|
||||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_roman_r_m
|
if [ -f target/CalculateAverage_roman_r_m_image ]; then
|
||||||
|
echo "Picking up existing native image 'target/CalculateAverage_roman_r_m_image', delete the file to select JVM mode." 1>&2
|
||||||
|
target/CalculateAverage_roman_r_m_image
|
||||||
|
else
|
||||||
|
JAVA_OPTS="--enable-preview -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA"
|
||||||
|
if [[ ! "$(uname -s)" = "Darwin" ]]; then
|
||||||
|
# On OS/X, my machine, this errors:
|
||||||
|
JAVA_OPTS="$JAVA_OPTS -XX:+UseTransparentHugePages"
|
||||||
|
fi
|
||||||
|
echo "Choosing to run the app in JVM mode as no native image was found, use additional_build_step_roman_r_m.sh to generate." 1>&2
|
||||||
|
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_roman_r_m
|
||||||
|
fi
|
||||||
|
@ -17,3 +17,12 @@
|
|||||||
|
|
||||||
source "$HOME/.sdkman/bin/sdkman-init.sh"
|
source "$HOME/.sdkman/bin/sdkman-init.sh"
|
||||||
sdk use java 21.0.1-graal 1>&2
|
sdk use java 21.0.1-graal 1>&2
|
||||||
|
|
||||||
|
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
|
||||||
|
if [ ! -f target/CalculateAverage_roman_r_m_image ]; then
|
||||||
|
|
||||||
|
JAVA_OPTS="--enable-preview -dsa"
|
||||||
|
NATIVE_IMAGE_OPTS="--initialize-at-build-time=dev.morling.onebrc.CalculateAverage_roman_r_m --gc=epsilon -Ob -O3 -march=native --strict-image-heap $JAVA_OPTS"
|
||||||
|
|
||||||
|
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_roman_r_m_image dev.morling.onebrc.CalculateAverage_roman_r_m
|
||||||
|
fi
|
@ -64,119 +64,6 @@ public class CalculateAverage_roman_r_m {
|
|||||||
return start + Long.numberOfTrailingZeros(i) / 8;
|
return start + Long.numberOfTrailingZeros(i) / 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
static class Worker {
|
|
||||||
private final MemorySegment ms;
|
|
||||||
private final long end;
|
|
||||||
private long offset;
|
|
||||||
|
|
||||||
public Worker(FileChannel channel, long start, long end) {
|
|
||||||
try {
|
|
||||||
this.ms = channel.map(FileChannel.MapMode.READ_ONLY, start, end - start, Arena.ofConfined());
|
|
||||||
this.offset = ms.address();
|
|
||||||
this.end = ms.address() + end - start;
|
|
||||||
}
|
|
||||||
catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void parseName(ByteString station) {
|
|
||||||
long start = offset;
|
|
||||||
long next = UNSAFE.getLong(offset);
|
|
||||||
long pattern = applyPattern(next, SEMICOLON_MASK);
|
|
||||||
int bytes;
|
|
||||||
if (pattern != 0) {
|
|
||||||
bytes = Long.numberOfTrailingZeros(pattern) / 8;
|
|
||||||
offset += bytes;
|
|
||||||
long h = Long.reverseBytes(next) >>> (8 * (8 - bytes));
|
|
||||||
station.hash = (int) (h ^ (h >>> 32));
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
long h = next;
|
|
||||||
station.hash = (int) (h ^ (h >>> 32));
|
|
||||||
while (pattern == 0) {
|
|
||||||
offset += 8;
|
|
||||||
next = UNSAFE.getLong(offset);
|
|
||||||
pattern = applyPattern(next, SEMICOLON_MASK);
|
|
||||||
}
|
|
||||||
bytes = Long.numberOfTrailingZeros(pattern) / 8;
|
|
||||||
offset += bytes;
|
|
||||||
}
|
|
||||||
|
|
||||||
int len = (int) (offset - start);
|
|
||||||
station.offset = start;
|
|
||||||
station.len = len;
|
|
||||||
station.tail = next & ((1L << (8 * bytes)) - 1);
|
|
||||||
|
|
||||||
offset++;
|
|
||||||
}
|
|
||||||
|
|
||||||
int parseNumberFast() {
|
|
||||||
long encodedVal = UNSAFE.getLong(offset);
|
|
||||||
|
|
||||||
int neg = 1 - Integer.bitCount((int) (encodedVal & 0x10));
|
|
||||||
encodedVal >>>= 8 * neg;
|
|
||||||
|
|
||||||
var len = applyPattern(encodedVal, DOT_MASK);
|
|
||||||
len = Long.numberOfTrailingZeros(len) / 8;
|
|
||||||
|
|
||||||
encodedVal ^= broadcast((byte) 0x30);
|
|
||||||
|
|
||||||
int intPart = (int) (encodedVal & ((1 << (8 * len)) - 1));
|
|
||||||
intPart <<= 8 * (2 - len);
|
|
||||||
intPart *= (100 * 256 + 10);
|
|
||||||
intPart = (intPart & 0x3FF80) >>> 8;
|
|
||||||
|
|
||||||
int frac = (int) ((encodedVal >>> (8 * (len + 1))) & 0xFF);
|
|
||||||
|
|
||||||
offset += neg + len + 3; // 1 for . + 1 for fractional part + 1 for new line char
|
|
||||||
int sign = 1 - 2 * neg;
|
|
||||||
int val = intPart + frac;
|
|
||||||
return sign * val;
|
|
||||||
}
|
|
||||||
|
|
||||||
int parseNumberSlow() {
|
|
||||||
int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10);
|
|
||||||
offset += neg;
|
|
||||||
|
|
||||||
int val = UNSAFE.getByte(offset++) - '0';
|
|
||||||
byte b;
|
|
||||||
while ((b = UNSAFE.getByte(offset++)) != '.') {
|
|
||||||
val = val * 10 + (b - '0');
|
|
||||||
}
|
|
||||||
b = UNSAFE.getByte(offset);
|
|
||||||
val = val * 10 + (b - '0');
|
|
||||||
offset += 2;
|
|
||||||
val *= 1 - 2 * neg;
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
|
|
||||||
int parseNumber() {
|
|
||||||
if (end - offset >= 8) {
|
|
||||||
return parseNumberFast();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
return parseNumberSlow();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public TreeMap<String, ResultRow> run() {
|
|
||||||
var resultStore = new ResultStore();
|
|
||||||
var station = new ByteString(ms);
|
|
||||||
|
|
||||||
while (offset < end) {
|
|
||||||
parseName(station);
|
|
||||||
long val = parseNumber();
|
|
||||||
var a = resultStore.get(station);
|
|
||||||
a.min = Math.min(a.min, val);
|
|
||||||
a.max = Math.max(a.max, val);
|
|
||||||
a.sum += val;
|
|
||||||
a.count++;
|
|
||||||
}
|
|
||||||
return resultStore.toMap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
Field f = Unsafe.class.getDeclaredField("theUnsafe");
|
Field f = Unsafe.class.getDeclaredField("theUnsafe");
|
||||||
f.setAccessible(true);
|
f.setAccessible(true);
|
||||||
@ -200,12 +87,97 @@ public class CalculateAverage_roman_r_m {
|
|||||||
var result = IntStream.range(0, numThreads)
|
var result = IntStream.range(0, numThreads)
|
||||||
.parallel()
|
.parallel()
|
||||||
.mapToObj(i -> {
|
.mapToObj(i -> {
|
||||||
long start = i == 0 ? 0 : bounds[i - 1] + 1;
|
try {
|
||||||
long end = bounds[i];
|
long segmentStart = i == 0 ? 0 : bounds[i - 1] + 1;
|
||||||
Worker worker = new Worker(channel, start, end);
|
long segmentEnd = bounds[i];
|
||||||
var res = worker.run();
|
var segment = channel.map(FileChannel.MapMode.READ_ONLY, segmentStart, segmentEnd - segmentStart, Arena.ofConfined());
|
||||||
worker.ms.unload();
|
|
||||||
return res;
|
var resultStore = new ResultStore();
|
||||||
|
var station = new ByteString(segment);
|
||||||
|
long offset = segment.address();
|
||||||
|
long end = offset + segment.byteSize();
|
||||||
|
while (offset < end) {
|
||||||
|
// parsing station name
|
||||||
|
long start = offset;
|
||||||
|
long next = UNSAFE.getLong(offset);
|
||||||
|
long pattern = applyPattern(next, SEMICOLON_MASK);
|
||||||
|
int bytes;
|
||||||
|
if (pattern != 0) {
|
||||||
|
bytes = Long.numberOfTrailingZeros(pattern) / 8;
|
||||||
|
offset += bytes;
|
||||||
|
long h = Long.reverseBytes(next) >>> (8 * (8 - bytes));
|
||||||
|
station.hash = (int) (h ^ (h >>> 32));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
long h = next;
|
||||||
|
station.hash = (int) (h ^ (h >>> 32));
|
||||||
|
while (pattern == 0) {
|
||||||
|
offset += 8;
|
||||||
|
next = UNSAFE.getLong(offset);
|
||||||
|
pattern = applyPattern(next, SEMICOLON_MASK);
|
||||||
|
}
|
||||||
|
bytes = Long.numberOfTrailingZeros(pattern) / 8;
|
||||||
|
offset += bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
int len = (int) (offset - start);
|
||||||
|
station.offset = start;
|
||||||
|
station.len = len;
|
||||||
|
station.tail = next & ((1L << (8 * bytes)) - 1);
|
||||||
|
|
||||||
|
offset++;
|
||||||
|
|
||||||
|
// parsing temperature
|
||||||
|
// TODO next may contain temperature as well, maybe try using it if we know the full number is there
|
||||||
|
// 8 - bytes >= 5 -> bytes <= 3
|
||||||
|
long val;
|
||||||
|
if (end - offset >= 8) {
|
||||||
|
long encodedVal = UNSAFE.getLong(offset);
|
||||||
|
|
||||||
|
int neg = 1 - Integer.bitCount((int) (encodedVal & 0x10));
|
||||||
|
encodedVal >>>= 8 * neg;
|
||||||
|
|
||||||
|
long numLen = applyPattern(encodedVal, DOT_MASK);
|
||||||
|
numLen = Long.numberOfTrailingZeros(numLen) / 8;
|
||||||
|
|
||||||
|
encodedVal ^= broadcast((byte) 0x30);
|
||||||
|
|
||||||
|
int intPart = (int) (encodedVal & ((1 << (8 * numLen)) - 1));
|
||||||
|
intPart <<= 8 * (2 - numLen);
|
||||||
|
intPart *= (100 * 256 + 10);
|
||||||
|
intPart = (intPart & 0x3FF80) >>> 8;
|
||||||
|
|
||||||
|
int frac = (int) ((encodedVal >>> (8 * (numLen + 1))) & 0xFF);
|
||||||
|
|
||||||
|
offset += neg + numLen + 3; // 1 for . + 1 for fractional part + 1 for new line char
|
||||||
|
int sign = 1 - 2 * neg;
|
||||||
|
val = sign * (intPart + frac);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10);
|
||||||
|
offset += neg;
|
||||||
|
|
||||||
|
val = UNSAFE.getByte(offset++) - '0';
|
||||||
|
byte b;
|
||||||
|
while ((b = UNSAFE.getByte(offset++)) != '.') {
|
||||||
|
val = val * 10 + (b - '0');
|
||||||
|
}
|
||||||
|
b = UNSAFE.getByte(offset);
|
||||||
|
val = val * 10 + (b - '0');
|
||||||
|
offset += 2;
|
||||||
|
val *= 1 - (2L * neg);
|
||||||
|
}
|
||||||
|
|
||||||
|
resultStore.update(station, (int) val);
|
||||||
|
}
|
||||||
|
|
||||||
|
segment.unload();
|
||||||
|
|
||||||
|
return resultStore.toMap();
|
||||||
|
}
|
||||||
|
catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
}).reduce((m1, m2) -> {
|
}).reduce((m1, m2) -> {
|
||||||
m2.forEach((k, v) -> m1.merge(k, v, ResultRow::merge));
|
m2.forEach((k, v) -> m1.merge(k, v, ResultRow::merge));
|
||||||
return m1;
|
return m1;
|
||||||
@ -275,10 +247,17 @@ public class CalculateAverage_roman_r_m {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static final class ResultRow {
|
private static final class ResultRow {
|
||||||
long min = 1000;
|
long min;
|
||||||
long sum = 0;
|
long sum;
|
||||||
long max = -1000;
|
long max;
|
||||||
int count = 0;
|
int count;
|
||||||
|
|
||||||
|
public ResultRow(int[] values) {
|
||||||
|
min = values[0];
|
||||||
|
max = values[1];
|
||||||
|
sum = values[2];
|
||||||
|
count = values[3];
|
||||||
|
}
|
||||||
|
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return round(min / 10.0) + "/" + round(sum / 10.0 / count) + "/" + round(max / 10.0);
|
return round(min / 10.0) + "/" + round(sum / 10.0 / count) + "/" + round(max / 10.0);
|
||||||
@ -300,9 +279,9 @@ public class CalculateAverage_roman_r_m {
|
|||||||
static class ResultStore {
|
static class ResultStore {
|
||||||
private static final int SIZE = 16384;
|
private static final int SIZE = 16384;
|
||||||
private final ByteString[] keys = new ByteString[SIZE];
|
private final ByteString[] keys = new ByteString[SIZE];
|
||||||
private final ResultRow[] values = new ResultRow[SIZE];
|
private final int[][] values = new int[SIZE][];
|
||||||
|
|
||||||
ResultRow get(ByteString s) {
|
void update(ByteString s, int value) {
|
||||||
int h = s.hashCode();
|
int h = s.hashCode();
|
||||||
int idx = (SIZE - 1) & h;
|
int idx = (SIZE - 1) & h;
|
||||||
|
|
||||||
@ -311,18 +290,20 @@ public class CalculateAverage_roman_r_m {
|
|||||||
i++;
|
i++;
|
||||||
idx = (idx + i * i) % SIZE;
|
idx = (idx + i * i) % SIZE;
|
||||||
}
|
}
|
||||||
ResultRow result;
|
|
||||||
if (keys[idx] == null) {
|
if (keys[idx] == null) {
|
||||||
keys[idx] = s.copy();
|
keys[idx] = s.copy();
|
||||||
result = new ResultRow();
|
values[idx] = new int[4];
|
||||||
values[idx] = result;
|
values[idx][0] = value;
|
||||||
|
values[idx][1] = value;
|
||||||
|
values[idx][2] = value;
|
||||||
|
values[idx][3] = 1;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
result = values[idx];
|
values[idx][0] = Math.min(values[idx][0], value);
|
||||||
// TODO see it it makes any difference
|
values[idx][1] = Math.max(values[idx][1], value);
|
||||||
// keys[idx].offset = s.offset;
|
values[idx][2] += value;
|
||||||
|
values[idx][3] += 1;
|
||||||
}
|
}
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TreeMap<String, ResultRow> toMap() {
|
TreeMap<String, ResultRow> toMap() {
|
||||||
@ -330,7 +311,7 @@ public class CalculateAverage_roman_r_m {
|
|||||||
var result = new TreeMap<String, ResultRow>();
|
var result = new TreeMap<String, ResultRow>();
|
||||||
for (int i = 0; i < SIZE; i++) {
|
for (int i = 0; i < SIZE; i++) {
|
||||||
if (keys[i] != null) {
|
if (keys[i] != null) {
|
||||||
result.put(keys[i].asString(buf), values[i]);
|
result.put(keys[i].asString(buf), new ResultRow(values[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
Loading…
Reference in New Issue
Block a user