diff --git a/calculate_average_armandino.sh b/calculate_average_armandino.sh index 6ac5c16..21a4f8c 100755 --- a/calculate_average_armandino.sh +++ b/calculate_average_armandino.sh @@ -15,6 +15,11 @@ # limitations under the License. # - -JAVA_OPTS="--enable-preview -da -dsa -Xms128m -Xmx128m -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -XX:+AlwaysPreTouch" -java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_armandino +if [ -f target/CalculateAverage_armandino_image ]; then + echo "Picking up existing native image 'target/CalculateAverage_armandino_image', delete the file to select JVM mode." 1>&2 + target/CalculateAverage_armandino_image +else + echo "Chosing to run the app in JVM mode as no native image was found, use prepare_armandino.sh to generate." 1>&2 + JAVA_OPTS="--enable-preview -da -dsa -Xms128m -Xmx128m -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -XX:+AlwaysPreTouch" + java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_armandino +fi diff --git a/prepare_armandino.sh b/prepare_armandino.sh new file mode 100755 index 0000000..19a71f9 --- /dev/null +++ b/prepare_armandino.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.2-graal 1>&2 + +# ./mvnw clean verify removes target/ and will re-trigger native image creation. +if [ ! -f target/CalculateAverage_armandino_image ]; then + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_armandino\$Scanner" + native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_armandino_image dev.morling.onebrc.CalculateAverage_armandino +fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java b/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java index dce3a33..d825e77 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java @@ -45,6 +45,7 @@ public class CalculateAverage_armandino { private static final byte DOT = 46; private static final byte MINUS = 45; private static final byte ZERO_DIGIT = 48; + private static final int PRIME = 1117; private static final Unsafe UNSAFE = getUnsafe(); public static void main(String[] args) throws Exception { @@ -78,7 +79,7 @@ public class CalculateAverage_armandino { byte b; while ((b = UNSAFE.getByte(i++)) != SEMICOLON) { - keyHash = 31 * keyHash + b; + keyHash = PRIME * keyHash + b; } final int keyLength = (int) (i - keyAddress - 1); @@ -114,13 +115,14 @@ public class CalculateAverage_armandino { stats.sum += measurement; stats.count++; } + return map; } } private static class Stats implements Comparable { private String key; - private final byte[] keyBytes; + private final long keyAddress; private final int keyLength; private final int keyHash; private int min = Integer.MAX_VALUE; @@ -129,17 +131,15 @@ public class CalculateAverage_armandino { private long sum; private Stats(long keyAddress, int keyLength, int keyHash) { + this.keyAddress = keyAddress; this.keyLength = keyLength; - this.keyBytes = new byte[keyLength]; this.keyHash = keyHash; - - for (int i = 0; i < keyLength; i++) { - keyBytes[i] = UNSAFE.getByte(keyAddress++); - } } String getKey() { if (key == null) { + var keyBytes = new byte[keyLength]; + UNSAFE.copyMemory(null, keyAddress, keyBytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, keyLength); key = new String(keyBytes, 0, keyLength, UTF_8); } return key; @@ -230,37 +230,6 @@ public class CalculateAverage_armandino { return Arrays.stream(table).filter(Objects::nonNull); } - private void resize() { - var copy = new SimpleMap(table.length * 2); - for (Stats s : table) { - if (s != null) { - final int pos = (copy.table.length - 1) & s.keyHash; - int i = pos; - - if (copy.table[i] == null) { - copy.table[i] = s; - continue; - } - - while (i < copy.table.length && copy.table[i] != null) { - i++; - } - if (i == copy.table.length) { - i = pos; - while (i >= 0 && copy.table[i] != null) { - i--; - } - } - if (i < 0) { - // shouldn't happen because put() is called after increasing size - throw new IllegalStateException("table is full"); - } - copy.table[i] = s; - } - } - table = copy.table; - } - Stats putStats(final int keyHash, final long keyAddress, final int keyLength) { final int pos = (table.length - 1) & keyHash; @@ -291,22 +260,49 @@ public class CalculateAverage_armandino { return putStats(keyHash, keyAddress, keyLength); } - private boolean keysEqual(Stats stats, long keyAddress, final int keyLength) { - if (stats.keyLength != keyLength) { - return false; - } - for (int i = 0; i < keyLength; i++) { - if (stats.keyBytes[i] != UNSAFE.getByte(keyAddress++)) { - return false; - } - } - return true; - } - private static Stats createAt(Stats[] table, long keyAddress, int keyLength, int key, int i) { Stats stats = new Stats(keyAddress, keyLength, key); table[i] = stats; return stats; } + + private static boolean keysEqual(Stats stats, long keyAddress, final int keyLength) { + // credit: abeobk + long xsum = 0; + int n = keyLength & 0xF8; + for (int i = 0; i < n; i += 8) { + xsum |= (UNSAFE.getLong(stats.keyAddress + i) ^ UNSAFE.getLong(keyAddress + i)); + } + return xsum == 0; + } + + private void resize() { + var copy = new SimpleMap(table.length * 2); + for (Stats s : table) { + if (s != null) { + final int pos = (copy.table.length - 1) & s.keyHash; + int i = pos; + if (copy.table[i] == null) { + copy.table[i] = s; + continue; + } + while (i < copy.table.length && copy.table[i] != null) { + i++; + } + if (i == copy.table.length) { + i = pos; + while (i >= 0 && copy.table[i] != null) { + i--; + } + } + if (i < 0) { + // if we reach here it's a bug! + throw new IllegalStateException("table is full"); + } + copy.table[i] = s; + } + } + table = copy.table; + } } }