mtopolnik submission 3 (#637)
* calculate_average_mtopolnik * short hash (just first 8 bytes of name) * Remove unneeded checks * Remove archiving classes * 2x larger hashtable * Add "set" to setters * Simplify parsing temperature, remove newline search * Reduce the size of the name slot * Store name length and use to detect collision * Reduce memory loads in parseTemperature * Use short for min/max * Extract constant for semicolon * Fix script header * Explicit bash shell in shebang * Inline usage of broadcast semicolon * Try vectorization * Remove vectorization * Go Unsafe * Use SWAR temperature parsing by merykitty * Inline some things * Remove commented-out MemorySegment usage * Inline namesMem.asSlice() invocation * Try out JVM JIT flags * Implement strcmp * Remove unused instance variables * Optimize hashing * Put station name into hashtable * Reorder method * Remove usage of MemorySegment.getUtf8String Replace with UNSAFE.copyMemory() and new String() * Fix hashing bug * Remove outdated comments * Fix informative constants * Use broadcastByte() more * Improve method naming * More hashing * Revert more hashing * Add commented-out code to hash 16 bytes * Slight cleanup * Align hashtable at cacheline boundary * Add Graal Native image * Revert Graal Native image This reverts commit d916a42326d89bd1a841bbbecfae185adb8679d7. 
* Simplify shell script (no SDK selection) * Move a constant, zero out hashtable on start * Better name comparison * Add prepare_mtopolnik.sh * Cleaner idiom in name comparison * AND instead of MOD for hashtable indexing * Improve word masking code * Fix formatting * Reduce memory loads * Remove endianness checks * Avoid hash == 0 problem * Fix subtle bug * MergeSort of parallel results * Touch up perf * Touch up perf * Remove -Xmx256m * Extract result printing method * Print allocation details on OOME * Single mmap * Use global allocation arena * Add commented-out Xmx64m XXMaxDirectMemorySize=1g * withinSafeZone * Update cursor earlier * Better assert * Fix bug in addrOfSemicolonSafe * Move declaration lower * Simplify code * Add rounding error test case * Fix DANGER_ZONE_LEN * Deoptimize parseTemperatureSimple() * Inline parseTemperatureAndAdvanceCursor() * Skip masking until the last load * Conditionally fetch name words * Cleanup * Use native image * Use subprocess * Simpler code * Cleanup * Avoid extra condition on hot path
This commit is contained in:
		@@ -15,5 +15,11 @@
 | 
			
		||||
#  limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
# Run the pre-built native image when it exists; otherwise fall back to the JVM.
# The program is launched exactly once: a leftover unconditional `java` invocation
# ahead of this check would have run the whole calculation a second time.
if [ -f target/CalculateAverage_mtopolnik_image ]; then
    echo "Using native image 'target/CalculateAverage_mtopolnik_image'" 1>&2
    target/CalculateAverage_mtopolnik_image
else
    JAVA_OPTS="--enable-preview"
    echo "Native image not found, using JVM mode." 1>&2
    # JAVA_OPTS is deliberately unquoted so that multiple options split into separate words.
    java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_mtopolnik
fi
 | 
			
		||||
 
 | 
			
		||||
@@ -16,4 +16,9 @@
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
# Select the GraalVM JDK used for building the native image.
# Only one `sdk use` is issued: an earlier selection of 21.0.1-graal was immediately
# overridden by this one and had no effect on the build.
source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.2-graal 1>&2

# Build the native image once; an existing image is reused as-is.
if [ ! -f target/CalculateAverage_mtopolnik_image ]; then
    NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -H:+UnlockExperimentalVMOptions -H:-GenLoopSafepoints -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_mtopolnik"
    # NATIVE_IMAGE_OPTS is deliberately unquoted so that each option becomes its own argument.
    native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_mtopolnik_image dev.morling.onebrc.CalculateAverage_mtopolnik
fi
 | 
			
		||||
 
 | 
			
		||||
@@ -29,18 +29,15 @@ import java.nio.charset.StandardCharsets;
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
 | 
			
		||||
import static java.lang.ProcessBuilder.Redirect.PIPE;
 | 
			
		||||
import static java.util.Arrays.asList;
 | 
			
		||||
 | 
			
		||||
public class CalculateAverage_mtopolnik {
 | 
			
		||||
    private static final Unsafe UNSAFE = unsafe();
 | 
			
		||||
    private static final int MAX_NAME_LEN = 100;
 | 
			
		||||
    private static final int STATS_TABLE_SIZE = 1 << 16;
 | 
			
		||||
    private static final int TABLE_INDEX_MASK = STATS_TABLE_SIZE - 1;
 | 
			
		||||
    private static final String MEASUREMENTS_TXT = "measurements.txt";
 | 
			
		||||
    private static final byte SEMICOLON = ';';
 | 
			
		||||
    private static final long BROADCAST_SEMICOLON = broadcastByte(SEMICOLON);
 | 
			
		||||
 | 
			
		||||
    // These two are just informative, I let the IDE calculate them for me
 | 
			
		||||
    private static final long NATIVE_MEM_PER_THREAD = StatsAccessor.SIZEOF * STATS_TABLE_SIZE;
 | 
			
		||||
    private static final long NATIVE_MEM_ON_8_THREADS = 8 * NATIVE_MEM_PER_THREAD;
 | 
			
		||||
 | 
			
		||||
    private static Unsafe unsafe() {
 | 
			
		||||
        try {
 | 
			
		||||
@@ -53,31 +50,23 @@ public class CalculateAverage_mtopolnik {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static class StationStats implements Comparable<StationStats> {
 | 
			
		||||
        String name;
 | 
			
		||||
        long sum;
 | 
			
		||||
        int count;
 | 
			
		||||
        int min;
 | 
			
		||||
        int max;
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public String toString() {
 | 
			
		||||
            return String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum / count) / 10.0, max / 10.0);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public boolean equals(Object that) {
 | 
			
		||||
            return that.getClass() == StationStats.class && ((StationStats) that).name.equals(this.name);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public int compareTo(StationStats that) {
 | 
			
		||||
            return name.compareTo(that.name);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static void main(String[] args) throws Exception {
 | 
			
		||||
        calculate();
 | 
			
		||||
        if (args.length >= 1 && args[0].equals("--worker")) {
 | 
			
		||||
            calculate();
 | 
			
		||||
            System.out.close();
 | 
			
		||||
            return;
 | 
			
		||||
        }
 | 
			
		||||
        var curProcInfo = ProcessHandle.current().info();
 | 
			
		||||
        var cmdLine = new ArrayList<String>();
 | 
			
		||||
        cmdLine.add(curProcInfo.command().get());
 | 
			
		||||
        cmdLine.addAll(asList(curProcInfo.arguments().get()));
 | 
			
		||||
        cmdLine.add("--worker");
 | 
			
		||||
        var process = new ProcessBuilder()
 | 
			
		||||
                .command(cmdLine)
 | 
			
		||||
                .inheritIO().redirectOutput(PIPE)
 | 
			
		||||
                .start()
 | 
			
		||||
                .getInputStream().transferTo(System.out);
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static void calculate() throws Exception {
 | 
			
		||||
@@ -113,7 +102,6 @@ public class CalculateAverage_mtopolnik {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class ChunkProcessor implements Runnable {
 | 
			
		||||
        private static final long NAMEBUF_SIZE = 2 * Long.BYTES;
 | 
			
		||||
        private static final int CACHELINE_SIZE = 64;
 | 
			
		||||
 | 
			
		||||
        private final long inputBase;
 | 
			
		||||
@@ -122,8 +110,6 @@ public class CalculateAverage_mtopolnik {
 | 
			
		||||
        private final int myIndex;
 | 
			
		||||
 | 
			
		||||
        private StatsAccessor stats;
 | 
			
		||||
        private long nameBufBase;
 | 
			
		||||
        private long cursor;
 | 
			
		||||
 | 
			
		||||
        ChunkProcessor(long chunkStart, long chunkLimit, StationStats[][] results, int myIndex) {
 | 
			
		||||
            this.inputBase = chunkStart;
 | 
			
		||||
@@ -138,16 +124,12 @@ public class CalculateAverage_mtopolnik {
 | 
			
		||||
                long totalAllocated = 0;
 | 
			
		||||
                String threadName = Thread.currentThread().getName();
 | 
			
		||||
                long statsByteSize = STATS_TABLE_SIZE * StatsAccessor.SIZEOF;
 | 
			
		||||
                var diagnosticString = String.format("Thread %s needs %,d bytes, managed to allocate before OOM: ",
 | 
			
		||||
                        threadName, statsByteSize + NAMEBUF_SIZE);
 | 
			
		||||
                var diagnosticString = String.format("Thread %s needs %,d bytes", threadName, statsByteSize);
 | 
			
		||||
                try {
 | 
			
		||||
                    stats = new StatsAccessor(confinedArena.allocate(statsByteSize, CACHELINE_SIZE));
 | 
			
		||||
                    totalAllocated = statsByteSize;
 | 
			
		||||
                    nameBufBase = confinedArena.allocate(NAMEBUF_SIZE).address();
 | 
			
		||||
                }
 | 
			
		||||
                catch (OutOfMemoryError e) {
 | 
			
		||||
                    System.err.print(diagnosticString);
 | 
			
		||||
                    System.err.println(totalAllocated);
 | 
			
		||||
                    throw e;
 | 
			
		||||
                }
 | 
			
		||||
                processChunk();
 | 
			
		||||
@@ -155,227 +137,110 @@ public class CalculateAverage_mtopolnik {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private static final int MAX_TEMPERATURE_LEN = 5;
 | 
			
		||||
        private static final int MAX_ROW_LEN = MAX_NAME_LEN + 1 + MAX_TEMPERATURE_LEN + 1;
 | 
			
		||||
        private static final long DANGER_ZONE_LENGTH = ((MAX_ROW_LEN - 1) / 8 * 8 + 8);
 | 
			
		||||
 | 
			
		||||
        private void processChunk() {
 | 
			
		||||
            final long inputSize = this.inputSize;
 | 
			
		||||
            final long inputBase = this.inputBase;
 | 
			
		||||
            long cursor = 0;
 | 
			
		||||
            long lastNameWord;
 | 
			
		||||
            while (cursor < inputSize) {
 | 
			
		||||
                boolean withinSafeZone;
 | 
			
		||||
                long word1;
 | 
			
		||||
                long word2;
 | 
			
		||||
                long nameLen;
 | 
			
		||||
                long nameStartAddress = inputBase + cursor;
 | 
			
		||||
                if (cursor + DANGER_ZONE_LENGTH <= inputSize) {
 | 
			
		||||
                    withinSafeZone = true;
 | 
			
		||||
                    word1 = UNSAFE.getLong(nameStartAddress);
 | 
			
		||||
                    word2 = UNSAFE.getLong(nameStartAddress + Long.BYTES);
 | 
			
		||||
                    nameLen = nameLen(word1, word2, withinSafeZone);
 | 
			
		||||
                    word1 = maskWord(word1, nameLen);
 | 
			
		||||
                    word2 = maskWord(word2, nameLen - Long.BYTES);
 | 
			
		||||
                long nameWord0 = UNSAFE.getLong(nameStartAddress);
 | 
			
		||||
                long nameWord1 = 0;
 | 
			
		||||
                long matchBits = semicolonMatchBits(nameWord0);
 | 
			
		||||
                long hash;
 | 
			
		||||
                int nameLen;
 | 
			
		||||
                int temperature;
 | 
			
		||||
                if (matchBits != 0) {
 | 
			
		||||
                    nameLen = nameLen(matchBits);
 | 
			
		||||
                    nameWord0 = maskWord(nameWord0, matchBits);
 | 
			
		||||
                    cursor += nameLen;
 | 
			
		||||
                    long tempWord = UNSAFE.getLong(inputBase + cursor);
 | 
			
		||||
                    int dotPos = dotPos(tempWord);
 | 
			
		||||
                    temperature = parseTemperature(tempWord, dotPos);
 | 
			
		||||
                    cursor += (dotPos >> 3) + 3;
 | 
			
		||||
                    hash = hash(nameWord0);
 | 
			
		||||
                    if (stats.gotoName0(hash, nameWord0)) {
 | 
			
		||||
                        stats.observe(temperature);
 | 
			
		||||
                        continue;
 | 
			
		||||
                    }
 | 
			
		||||
                    lastNameWord = nameWord0;
 | 
			
		||||
                }
 | 
			
		||||
                else {
 | 
			
		||||
                    withinSafeZone = false;
 | 
			
		||||
                    UNSAFE.putLong(nameBufBase, 0);
 | 
			
		||||
                    UNSAFE.putLong(nameBufBase + Long.BYTES, 0);
 | 
			
		||||
                    UNSAFE.copyMemory(nameStartAddress, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor));
 | 
			
		||||
                    word1 = UNSAFE.getLong(nameBufBase);
 | 
			
		||||
                    word2 = UNSAFE.getLong(nameBufBase + Long.BYTES);
 | 
			
		||||
                    nameLen = nameLen(word1, word2, withinSafeZone);
 | 
			
		||||
                else { // nameLen > 8
 | 
			
		||||
                    hash = hash(nameWord0);
 | 
			
		||||
                    nameWord1 = UNSAFE.getLong(nameStartAddress + Long.BYTES);
 | 
			
		||||
                    matchBits = semicolonMatchBits(nameWord1);
 | 
			
		||||
                    if (matchBits != 0) {
 | 
			
		||||
                        nameLen = Long.BYTES + nameLen(matchBits);
 | 
			
		||||
                        nameWord1 = maskWord(nameWord1, matchBits);
 | 
			
		||||
                        cursor += nameLen;
 | 
			
		||||
                        long tempWord = UNSAFE.getLong(inputBase + cursor);
 | 
			
		||||
                        int dotPos = dotPos(tempWord);
 | 
			
		||||
                        temperature = parseTemperature(tempWord, dotPos);
 | 
			
		||||
                        cursor += (dotPos >> 3) + 3;
 | 
			
		||||
                        if (stats.gotoName1(hash, nameWord0, nameWord1)) {
 | 
			
		||||
                            stats.observe(temperature);
 | 
			
		||||
                            continue;
 | 
			
		||||
                        }
 | 
			
		||||
                        lastNameWord = nameWord1;
 | 
			
		||||
                    }
 | 
			
		||||
                    else { // nameLen > 16
 | 
			
		||||
                        nameLen = 2 * Long.BYTES;
 | 
			
		||||
                        while (true) {
 | 
			
		||||
                            lastNameWord = UNSAFE.getLong(nameStartAddress + nameLen);
 | 
			
		||||
                            matchBits = semicolonMatchBits(lastNameWord);
 | 
			
		||||
                            if (matchBits != 0) {
 | 
			
		||||
                                nameLen += nameLen(matchBits);
 | 
			
		||||
                                lastNameWord = maskWord(lastNameWord, matchBits);
 | 
			
		||||
                                cursor += nameLen;
 | 
			
		||||
                                long tempWord = UNSAFE.getLong(inputBase + cursor);
 | 
			
		||||
                                int dotPos = dotPos(tempWord);
 | 
			
		||||
                                temperature = parseTemperature(tempWord, dotPos);
 | 
			
		||||
                                cursor += (dotPos >> 3) + 3;
 | 
			
		||||
                                break;
 | 
			
		||||
                            }
 | 
			
		||||
                            nameLen += Long.BYTES;
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                long hash = hash(word1);
 | 
			
		||||
                assert nameLen > 0 && nameLen <= 100 : nameLen;
 | 
			
		||||
                long tempStartAddress = nameStartAddress + nameLen + 1;
 | 
			
		||||
                int temperature = withinSafeZone
 | 
			
		||||
                        ? parseTemperatureSwarAndAdvanceCursor(tempStartAddress)
 | 
			
		||||
                        : parseTemperatureSimpleAndAdvanceCursor(tempStartAddress);
 | 
			
		||||
                updateStats(hash, nameStartAddress, nameLen, word1, word2, temperature, withinSafeZone);
 | 
			
		||||
                stats.gotoAndObserve(hash, nameStartAddress, nameLen, nameWord0, nameWord1, lastNameWord, temperature);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private void updateStats(
 | 
			
		||||
                                 long hash, long nameStartAddress, long nameLen, long nameWord1, long nameWord2,
 | 
			
		||||
                                 int temperature, boolean withinSafeZone) {
 | 
			
		||||
            int tableIndex = (int) (hash & TABLE_INDEX_MASK);
 | 
			
		||||
            while (true) {
 | 
			
		||||
                stats.gotoIndex(tableIndex);
 | 
			
		||||
                if (stats.hash() == hash && stats.nameLen() == nameLen && nameEquals(
 | 
			
		||||
                        stats.nameAddress(), nameStartAddress, nameLen, nameWord1, nameWord2, withinSafeZone)) {
 | 
			
		||||
                    stats.setSum(stats.sum() + temperature);
 | 
			
		||||
                    stats.setCount(stats.count() + 1);
 | 
			
		||||
                    stats.setMin((short) Integer.min(stats.min(), temperature));
 | 
			
		||||
                    stats.setMax((short) Integer.max(stats.max(), temperature));
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                if (stats.nameLen() != 0) {
 | 
			
		||||
                    tableIndex = (tableIndex + 1) & TABLE_INDEX_MASK;
 | 
			
		||||
                    continue;
 | 
			
		||||
                }
 | 
			
		||||
                stats.setHash(hash);
 | 
			
		||||
                stats.setNameLen((int) nameLen);
 | 
			
		||||
                stats.setSum(temperature);
 | 
			
		||||
                stats.setCount(1);
 | 
			
		||||
                stats.setMin((short) temperature);
 | 
			
		||||
                stats.setMax((short) temperature);
 | 
			
		||||
                UNSAFE.copyMemory(nameStartAddress, stats.nameAddress(), nameLen);
 | 
			
		||||
                return;
 | 
			
		||||
            }
 | 
			
		||||
        private static final long BROADCAST_SEMICOLON = 0x3B3B3B3B3B3B3B3BL;
 | 
			
		||||
        private static final long BROADCAST_0x01 = 0x0101010101010101L;
 | 
			
		||||
        private static final long BROADCAST_0x80 = 0x8080808080808080L;
 | 
			
		||||
 | 
			
		||||
        private static long semicolonMatchBits(long word) {
 | 
			
		||||
            long diff = word ^ BROADCAST_SEMICOLON;
 | 
			
		||||
            return (diff - BROADCAST_0x01) & (~diff & BROADCAST_0x80);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Credit: merykitty
 | 
			
		||||
        private int parseTemperatureSwarAndAdvanceCursor(long tempStartAddress) {
 | 
			
		||||
            long word = UNSAFE.getLong(tempStartAddress);
 | 
			
		||||
            final long negated = ~word;
 | 
			
		||||
            final int dotPos = Long.numberOfTrailingZeros(negated & 0x10101000);
 | 
			
		||||
            cursor = (tempStartAddress + (dotPos / 8) + 3) - inputBase;
 | 
			
		||||
            final long signed = (negated << 59) >> 63;
 | 
			
		||||
        // credit: artsiomkorzun
 | 
			
		||||
        private static long maskWord(long word, long matchBits) {
 | 
			
		||||
            long mask = matchBits ^ (matchBits - 1);
 | 
			
		||||
            return word & mask;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // credit: merykitty
 | 
			
		||||
        private static int dotPos(long word) {
 | 
			
		||||
            return Long.numberOfTrailingZeros(~word & 0x10101000);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // credit: merykitty
 | 
			
		||||
        private static int parseTemperature(long word, int dotPos) {
 | 
			
		||||
            final long signed = (~word << 59) >> 63;
 | 
			
		||||
            final long removeSignMask = ~(signed & 0xFF);
 | 
			
		||||
            final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L;
 | 
			
		||||
            final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
 | 
			
		||||
            return (int) ((absValue ^ signed) - signed);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private int parseTemperatureSimpleAndAdvanceCursor(long tempStartAddress) {
 | 
			
		||||
            final byte minus = (byte) '-';
 | 
			
		||||
            final byte zero = (byte) '0';
 | 
			
		||||
            final byte dot = (byte) '.';
 | 
			
		||||
 | 
			
		||||
            byte ch = UNSAFE.getByte(tempStartAddress);
 | 
			
		||||
            long address = tempStartAddress;
 | 
			
		||||
            int temperature;
 | 
			
		||||
            int sign;
 | 
			
		||||
            if (ch == minus) {
 | 
			
		||||
                sign = -1;
 | 
			
		||||
                address++;
 | 
			
		||||
                ch = UNSAFE.getByte(address);
 | 
			
		||||
            }
 | 
			
		||||
            else {
 | 
			
		||||
                sign = 1;
 | 
			
		||||
            }
 | 
			
		||||
            temperature = ch - zero;
 | 
			
		||||
            address++;
 | 
			
		||||
            ch = UNSAFE.getByte(address);
 | 
			
		||||
            if (ch == dot) {
 | 
			
		||||
                address++;
 | 
			
		||||
                ch = UNSAFE.getByte(address);
 | 
			
		||||
            }
 | 
			
		||||
            else {
 | 
			
		||||
                temperature = 10 * temperature + (ch - zero);
 | 
			
		||||
                address += 2;
 | 
			
		||||
                ch = UNSAFE.getByte(address);
 | 
			
		||||
            }
 | 
			
		||||
            temperature = 10 * temperature + (ch - zero);
 | 
			
		||||
            // address - inputBase is the length of the temperature field.
 | 
			
		||||
            // A newline character follows the temperature, and so we advance
 | 
			
		||||
            // the cursor past the newline to the start of the next line.
 | 
			
		||||
            cursor = (address + 2) - inputBase;
 | 
			
		||||
            return sign * temperature;
 | 
			
		||||
        private static int nameLen(long separator) {
 | 
			
		||||
            return (Long.numberOfTrailingZeros(separator) >>> 3) + 1;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private static long hash(long word1) {
 | 
			
		||||
            long seed = 0x51_7c_c1_b7_27_22_0a_95L;
 | 
			
		||||
            int rotDist = 17;
 | 
			
		||||
 | 
			
		||||
            long hash = word1;
 | 
			
		||||
            hash *= seed;
 | 
			
		||||
            hash = Long.rotateLeft(hash, rotDist);
 | 
			
		||||
            // hash ^= word2;
 | 
			
		||||
            // hash *= seed;
 | 
			
		||||
            // hash = Long.rotateLeft(hash, rotDist);
 | 
			
		||||
            return hash;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2,
 | 
			
		||||
                                          boolean withinSafeZone) {
 | 
			
		||||
            boolean mismatch1 = maskWord(inputWord1, len) != UNSAFE.getLong(statsAddr);
 | 
			
		||||
            boolean mismatch2 = maskWord(inputWord2, len - Long.BYTES) != UNSAFE.getLong(statsAddr + Long.BYTES);
 | 
			
		||||
            if (len <= 2 * Long.BYTES) {
 | 
			
		||||
                return !(mismatch1 | mismatch2);
 | 
			
		||||
            }
 | 
			
		||||
            if (withinSafeZone) {
 | 
			
		||||
                int i = 2 * Long.BYTES;
 | 
			
		||||
                for (; i <= len - Long.BYTES; i += Long.BYTES) {
 | 
			
		||||
                    if (UNSAFE.getLong(inputAddr + i) != UNSAFE.getLong(statsAddr + i)) {
 | 
			
		||||
                        return false;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                return maskWord(UNSAFE.getLong(inputAddr + i), len - i) == UNSAFE.getLong(statsAddr + i);
 | 
			
		||||
            }
 | 
			
		||||
            else {
 | 
			
		||||
                for (int i = 2 * Long.BYTES; i < len; i++) {
 | 
			
		||||
                    if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) {
 | 
			
		||||
                        return false;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            return true;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private static long maskWord(long word, long len) {
 | 
			
		||||
            long halfShiftDistance = Long.max(0, Long.BYTES - len) << 2;
 | 
			
		||||
            long mask = (~0L >>> halfShiftDistance) >>> halfShiftDistance; // avoid Java trap of shiftDist % 64
 | 
			
		||||
            return word & mask;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private static final long BROADCAST_0x01 = broadcastByte(0x01);
 | 
			
		||||
        private static final long BROADCAST_0x80 = broadcastByte(0x80);
 | 
			
		||||
 | 
			
		||||
        // Adapted from https://jameshfisher.com/2017/01/24/bitwise-check-for-zero-byte/
 | 
			
		||||
        // and https://github.com/ashvardanian/StringZilla/blob/14e7a78edcc16b031c06b375aac1f66d8f19d45a/stringzilla/stringzilla.h#L139-L169
 | 
			
		||||
        long nameLen(long word1, long word2, boolean withinSafeZone) {
 | 
			
		||||
            {
 | 
			
		||||
                long matchBits1 = matchBits(word1);
 | 
			
		||||
                long matchBits2 = matchBits(word2);
 | 
			
		||||
                if ((matchBits1 | matchBits2) != 0) {
 | 
			
		||||
                    int trailing1 = Long.numberOfTrailingZeros(matchBits1);
 | 
			
		||||
                    int match1IsNonZero = trailing1 & 63;
 | 
			
		||||
                    match1IsNonZero |= match1IsNonZero >>> 3;
 | 
			
		||||
                    match1IsNonZero |= match1IsNonZero >>> 1;
 | 
			
		||||
                    match1IsNonZero |= match1IsNonZero >>> 1;
 | 
			
		||||
                    // Now match1IsNonZero is 1 if it's non-zero, else 0. Use it to
 | 
			
		||||
                    // raise the lowest bit in trailing2 if trailing1 is nonzero. This forces
 | 
			
		||||
                    // trailing2 to be zero if trailing1 is non-zero.
 | 
			
		||||
                    int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63;
 | 
			
		||||
                    // trailing1 | trailing2 works like trailing1 + trailing2 because if trailing2 is non-zero,
 | 
			
		||||
                    // then trailing1 is 64, and since trailing2 is < 64, there's no bit overlap.
 | 
			
		||||
                    return (trailing1 | trailing2) >> 3;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            long nameStartAddress = inputBase + cursor;
 | 
			
		||||
            long address = nameStartAddress + 2 * Long.BYTES;
 | 
			
		||||
            long limit = inputBase + inputSize;
 | 
			
		||||
            if (withinSafeZone) {
 | 
			
		||||
                for (; address < limit; address += Long.BYTES) {
 | 
			
		||||
                    var block = maskWord(UNSAFE.getLong(address), limit - address);
 | 
			
		||||
                    long matchBits = matchBits(block);
 | 
			
		||||
                    if (matchBits != 0) {
 | 
			
		||||
                        return address + (Long.numberOfTrailingZeros(matchBits) >> 3) - nameStartAddress;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                throw new RuntimeException("Semicolon not found");
 | 
			
		||||
            }
 | 
			
		||||
            return addrOfSemicolonSafe(address, limit) - nameStartAddress;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private static long addrOfSemicolonSafe(long address, long limit) {
 | 
			
		||||
            for (; address < limit - Long.BYTES + 1; address += Long.BYTES) {
 | 
			
		||||
                var block = UNSAFE.getLong(address);
 | 
			
		||||
                long matchBits = matchBits(block);
 | 
			
		||||
                if (matchBits != 0) {
 | 
			
		||||
                    return address + (Long.numberOfTrailingZeros(matchBits) >> 3);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            for (; address < limit; address++) {
 | 
			
		||||
                if (UNSAFE.getByte(address) == SEMICOLON) {
 | 
			
		||||
                    return address;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            throw new RuntimeException("Semicolon not found");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private static long matchBits(long word) {
 | 
			
		||||
            long diff = word ^ BROADCAST_SEMICOLON;
 | 
			
		||||
            return (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
 | 
			
		||||
        private static long hash(long word) {
 | 
			
		||||
            return Long.rotateLeft(word * 0x51_7c_c1_b7_27_22_0a_95L, 17);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Copies the results from native memory to Java heap and puts them into the results array.
 | 
			
		||||
@@ -403,22 +268,6 @@ public class CalculateAverage_mtopolnik {
 | 
			
		||||
            Arrays.sort(exported);
 | 
			
		||||
            results[myIndex] = exported;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private final ByteBuffer buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder());
 | 
			
		||||
 | 
			
		||||
        private String longToString(long word) {
 | 
			
		||||
            buf.clear();
 | 
			
		||||
            buf.putLong(word);
 | 
			
		||||
            return new String(buf.array(), StandardCharsets.UTF_8); // + "|" + Arrays.toString(buf.array());
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static long broadcastByte(int b) {
 | 
			
		||||
        long nnnnnnnn = b;
 | 
			
		||||
        nnnnnnnn |= nnnnnnnn << 8;
 | 
			
		||||
        nnnnnnnn |= nnnnnnnn << 16;
 | 
			
		||||
        nnnnnnnn |= nnnnnnnn << 32;
 | 
			
		||||
        return nnnnnnnn;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static class StatsAccessor {
 | 
			
		||||
@@ -446,6 +295,16 @@ public class CalculateAverage_mtopolnik {
 | 
			
		||||
            slotBase = address + index * SIZEOF;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private boolean gotoName0(long hash, long nameWord0) {
 | 
			
		||||
            gotoIndex((int) (hash & TABLE_INDEX_MASK));
 | 
			
		||||
            return hash() == hash && nameWord0() == nameWord0;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private boolean gotoName1(long hash, long nameWord0, long nameWord1) {
 | 
			
		||||
            gotoIndex((int) (hash & TABLE_INDEX_MASK));
 | 
			
		||||
            return hash() == hash && nameWord0() == nameWord0 && nameWord1() == nameWord1;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        long hash() {
 | 
			
		||||
            return UNSAFE.getLong(slotBase + HASH_OFFSET);
 | 
			
		||||
        }
 | 
			
		||||
@@ -474,9 +333,17 @@ public class CalculateAverage_mtopolnik {
 | 
			
		||||
            return slotBase + NAME_OFFSET;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        long nameWord0() {
 | 
			
		||||
            return UNSAFE.getLong(nameAddress());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        long nameWord1() {
 | 
			
		||||
            return UNSAFE.getLong(nameAddress() + Long.BYTES);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        String exportNameString() {
 | 
			
		||||
            final var bytes = new byte[nameLen()];
 | 
			
		||||
            UNSAFE.copyMemory(null, nameAddress(), bytes, ARRAY_BASE_OFFSET, nameLen());
 | 
			
		||||
            final var bytes = new byte[nameLen() - 1];
 | 
			
		||||
            UNSAFE.copyMemory(null, nameAddress(), bytes, ARRAY_BASE_OFFSET, bytes.length);
 | 
			
		||||
            return new String(bytes, StandardCharsets.UTF_8);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -503,6 +370,59 @@ public class CalculateAverage_mtopolnik {
 | 
			
		||||
        void setMax(short max) {
 | 
			
		||||
            UNSAFE.putShort(slotBase + MAX_OFFSET, max);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        void gotoAndObserve(
 | 
			
		||||
                            long hash, long nameStartAddress, int nameLen, long nameWord0, long nameWord1, long lastNameWord,
 | 
			
		||||
                            int temperature) {
 | 
			
		||||
            int tableIndex = (int) (hash & TABLE_INDEX_MASK);
 | 
			
		||||
            while (true) {
 | 
			
		||||
                gotoIndex(tableIndex);
 | 
			
		||||
                if (hash() == hash && nameLen() == nameLen && nameEquals(
 | 
			
		||||
                        nameAddress(), nameStartAddress, nameLen, nameWord0, nameWord1, lastNameWord)) {
 | 
			
		||||
                    observe(temperature);
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
                if (nameLen() != 0) {
 | 
			
		||||
                    tableIndex = (tableIndex + 1) & TABLE_INDEX_MASK;
 | 
			
		||||
                    continue;
 | 
			
		||||
                }
 | 
			
		||||
                initialize(hash, nameLen, nameStartAddress, temperature);
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        void initialize(long hash, long nameLen, long nameStartAddress, int temperature) {
 | 
			
		||||
            setHash(hash);
 | 
			
		||||
            setNameLen((int) nameLen);
 | 
			
		||||
            setSum(temperature);
 | 
			
		||||
            setCount(1);
 | 
			
		||||
            setMin((short) temperature);
 | 
			
		||||
            setMax((short) temperature);
 | 
			
		||||
            UNSAFE.copyMemory(nameStartAddress, nameAddress(), nameLen);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        void observe(int temperature) {
 | 
			
		||||
            setSum(sum() + temperature);
 | 
			
		||||
            setCount(count() + 1);
 | 
			
		||||
            setMin((short) Integer.min(min(), temperature));
 | 
			
		||||
            setMax((short) Integer.max(max(), temperature));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private static boolean nameEquals(
 | 
			
		||||
                                          long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2, long lastInputWord) {
 | 
			
		||||
            boolean mismatch1 = inputWord1 != UNSAFE.getLong(statsAddr);
 | 
			
		||||
            boolean mismatch2 = inputWord2 != UNSAFE.getLong(statsAddr + Long.BYTES);
 | 
			
		||||
            if (len <= 2 * Long.BYTES) {
 | 
			
		||||
                return !(mismatch1 | mismatch2);
 | 
			
		||||
            }
 | 
			
		||||
            int i = 2 * Long.BYTES;
 | 
			
		||||
            for (; i <= len - Long.BYTES; i += Long.BYTES) {
 | 
			
		||||
                if (UNSAFE.getLong(inputAddr + i) != UNSAFE.getLong(statsAddr + i)) {
 | 
			
		||||
                    return false;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            return i == len || lastInputWord == UNSAFE.getLong(statsAddr + i);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static void mergeSortAndPrint(StationStats[][] results) {
 | 
			
		||||
@@ -556,4 +476,34 @@ public class CalculateAverage_mtopolnik {
 | 
			
		||||
        }
 | 
			
		||||
        System.out.println('}');
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static class StationStats implements Comparable<StationStats> {
 | 
			
		||||
        String name;
 | 
			
		||||
        long sum;
 | 
			
		||||
        int count;
 | 
			
		||||
        int min;
 | 
			
		||||
        int max;
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public String toString() {
 | 
			
		||||
            return String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum / count) / 10.0, max / 10.0);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public boolean equals(Object that) {
 | 
			
		||||
            return that.getClass() == StationStats.class && ((StationStats) that).name.equals(this.name);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public int compareTo(StationStats that) {
 | 
			
		||||
            return name.compareTo(that.name);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static String longToString(long word) {
 | 
			
		||||
        final ByteBuffer buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder());
 | 
			
		||||
        buf.clear();
 | 
			
		||||
        buf.putLong(word);
 | 
			
		||||
        return new String(buf.array(), StandardCharsets.UTF_8);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user