From 95459f56407c5c75e7ef08500eb711e5c94654ab Mon Sep 17 00:00:00 2001 From: Marko Topolnik Date: Thu, 11 Jan 2024 20:02:14 +0100 Subject: [PATCH] Entry into the contest, calculate_average_mtopolnik.sh (#246) * calculate_average_mtopolnik * short hash (just first 8 bytes of name) * Remove unneeded checks * Remove archiving classes * 2x larger hashtable * Add "set" to setters * Simplify parsing temperature, remove newline search * Reduce the size of the name slot * Store name length and use to detect collision * Reduce memory loads in parseTemperature * Use short for min/max * Extract constant for semicolon * Fix script header * Explicit bash shell in shebang * Inline usage of broadcast semicolon * Try vectorization * Remove vectorization * Go Unsafe * Use SWAR temperature parsing by merykitty * Inline some things * Remove commented-out MemorySegment usage * Inline namesMem.asSlice() invocation * Try out JVM JIT flags * Implement strcmp * Remove unused instance variables * Optimize hashing * Put station name into hashtable * Reorder method * Remove usage of MemorySegment.getUtf8String Replace with UNSAFE.copyMemory() and new String() * Fix hashing bug * Remove outdated comments * Fix informative constants * Use broadcastByte() more * Improve method naming * More hashing * Revert more hashing * Add commented-out code to hash 16 bytes * Slight cleanup * Align hashtable at cacheline boundary * Add Graal Native image * Revert Graal Native image This reverts commit d916a42326d89bd1a841bbbecfae185adb8679d7. * Simplify shell script (no SDK selection) * Move a constant, zero out hashtable on start * Better name comparison * Add prepare_mtopolnik.sh * Cleaner idiom in name comparison * AND instead of MOD for hashtable indexing * Improve word masking code * Fix formatting * Reduce memory loads * Remove endianness checks * Avoid hash == 0 problem * Fix subtle bug * MergeSort of parellel results * Touch up perf * Touch up perf * Remove -Xmx256m * Extract result printing method * Print allocation details on OOME * Single mmap * Use global allocation arena --- calculate_average_mtopolnik.sh | 21 + prepare_mtopolnik.sh | 19 + .../onebrc/CalculateAverage_mtopolnik.java | 530 ++++++++++++++++++ 3 files changed, 570 insertions(+) create mode 100755 calculate_average_mtopolnik.sh create mode 100755 prepare_mtopolnik.sh create mode 100644 src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java diff --git a/calculate_average_mtopolnik.sh b/calculate_average_mtopolnik.sh new file mode 100755 index 0000000..e48711a --- /dev/null +++ b/calculate_average_mtopolnik.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_mtopolnik::recordMeasurementAndAdvanceCursor" +# -XX:InlineSmallCode=10000 -XX:-TieredCompilation -XX:CICompilerCount=2 -XX:CompileThreshold=1000\ +java --enable-preview \ + --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_mtopolnik diff --git a/prepare_mtopolnik.sh b/prepare_mtopolnik.sh new file mode 100755 index 0000000..a705f17 --- /dev/null +++ b/prepare_mtopolnik.sh @@ -0,0 +1,19 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.1-graal 1>&2 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java new file mode 100644 index 0000000..fe487fc --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java @@ -0,0 +1,530 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import sun.misc.Unsafe; + +import java.io.File; +import java.io.RandomAccessFile; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel.MapMode; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; + +public class CalculateAverage_mtopolnik { + private static final Unsafe UNSAFE = unsafe(); + private static final int MAX_NAME_LEN = 100; + private static final int STATS_TABLE_SIZE = 1 << 16; + private static final int TABLE_INDEX_MASK = STATS_TABLE_SIZE - 1; + private static final String MEASUREMENTS_TXT = "measurements.txt"; + private static final byte SEMICOLON = ';'; + private static final long BROADCAST_SEMICOLON = broadcastByte(SEMICOLON); + + // These two are just informative, I let the IDE calculate them for me + private static final long NATIVE_MEM_PER_THREAD = StatsAccessor.SIZEOF * STATS_TABLE_SIZE; + private static final long NATIVE_MEM_ON_8_THREADS = 8 * NATIVE_MEM_PER_THREAD; + + private static Unsafe unsafe() { + try { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + return (Unsafe) theUnsafe.get(Unsafe.class); + } + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + static class StationStats implements Comparable { + String name; + long sum; + int count; + int min; + int max; + + @Override + public String toString() { + return String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum / count) / 10.0, max / 10.0); + } + + @Override + public boolean equals(Object that) { + return that.getClass() == StationStats.class && ((StationStats) that).name.equals(this.name); + } + + @Override + public int compareTo(StationStats that) { + return name.compareTo(that.name); + } + } + + public static void main(String[] args) throws Exception { + calculate(); + } + + static void calculate() throws Exception { + final File file = new File(MEASUREMENTS_TXT); + final long length = file.length(); + final int chunkCount = Runtime.getRuntime().availableProcessors(); + final var results = new StationStats[chunkCount][]; + final var chunkStartOffsets = new long[chunkCount]; + try (var raf = new RandomAccessFile(file, "r")) { + final var inputBase = raf.getChannel().map(MapMode.READ_ONLY, 0, length, Arena.global()).address(); + for (int i = 1; i < chunkStartOffsets.length; i++) { + var start = length * i / chunkStartOffsets.length; + raf.seek(start); + while (raf.read() != (byte) '\n') { + } + start = raf.getFilePointer(); + chunkStartOffsets[i] = start; + } + var threads = new Thread[chunkCount]; + for (int i = 0; i < chunkCount; i++) { + final long chunkStart = chunkStartOffsets[i]; + final long chunkLimit = (i + 1 < chunkCount) ? chunkStartOffsets[i + 1] : length; + threads[i] = new Thread(new ChunkProcessor(inputBase + chunkStart, inputBase + chunkLimit, results, i)); + } + for (var thread : threads) { + thread.start(); + } + for (var thread : threads) { + thread.join(); + } + } + mergeSortAndPrint(results); + } + + private static class ChunkProcessor implements Runnable { + private static final long NAMEBUF_SIZE = 2 * Long.BYTES; + private static final int CACHELINE_SIZE = 64; + + private final long inputBase; + private final long inputSize; + private final StationStats[][] results; + private final int myIndex; + + private StatsAccessor stats; + private long nameBufBase; + private long cursor; + + ChunkProcessor(long chunkStart, long chunkLimit, StationStats[][] results, int myIndex) { + this.inputBase = chunkStart; + this.inputSize = chunkLimit - chunkStart; + this.results = results; + this.myIndex = myIndex; + } + + @Override + public void run() { + try (Arena confinedArena = Arena.ofConfined()) { + long totalAllocated = 0; + String threadName = Thread.currentThread().getName(); + long statsByteSize = STATS_TABLE_SIZE * StatsAccessor.SIZEOF; + var diagnosticString = String.format("Thread %s needs %,d bytes, managed to allocate before OOM: ", + threadName, statsByteSize + NAMEBUF_SIZE); + try { + stats = new StatsAccessor(confinedArena.allocate(statsByteSize, CACHELINE_SIZE)); + totalAllocated = statsByteSize; + nameBufBase = confinedArena.allocate(NAMEBUF_SIZE).address(); + } + catch (OutOfMemoryError e) { + System.err.print(diagnosticString); + System.err.println(totalAllocated); + throw e; + } + processChunk(); + exportResults(); + } + } + + private void processChunk() { + while (cursor < inputSize) { + long word1; + long word2; + if (cursor + 2 * Long.BYTES <= inputSize) { + word1 = UNSAFE.getLong(inputBase + cursor); + word2 = UNSAFE.getLong(inputBase + cursor + Long.BYTES); + } + else { + UNSAFE.putLong(nameBufBase, 0); + UNSAFE.putLong(nameBufBase + Long.BYTES, 0); + UNSAFE.copyMemory(inputBase + cursor, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor)); + word1 = UNSAFE.getLong(nameBufBase); + word2 = UNSAFE.getLong(nameBufBase + Long.BYTES); + } + long posOfSemicolon = posOfSemicolon(word1, word2); + word1 = maskWord(word1, posOfSemicolon - cursor); + word2 = maskWord(word2, posOfSemicolon - cursor - Long.BYTES); + long hash = hash(word1); + long namePos = cursor; + long nameLen = posOfSemicolon - cursor; + assert nameLen <= 100 : "nameLen > 100"; + int temperature = parseTemperatureAndAdvanceCursor(posOfSemicolon); + updateStats(hash, namePos, nameLen, word1, word2, temperature); + } + } + + private void updateStats(long hash, long namePos, long nameLen, long nameWord1, long nameWord2, int temperature) { + int tableIndex = (int) (hash & TABLE_INDEX_MASK); + while (true) { + stats.gotoIndex(tableIndex); + if (stats.hash() == hash && stats.nameLen() == nameLen + && nameEquals(stats.nameAddress(), inputBase + namePos, nameLen, nameWord1, nameWord2)) { + stats.setSum(stats.sum() + temperature); + stats.setCount(stats.count() + 1); + stats.setMin((short) Integer.min(stats.min(), temperature)); + stats.setMax((short) Integer.max(stats.max(), temperature)); + return; + } + if (stats.nameLen() != 0) { + tableIndex = (tableIndex + 1) & TABLE_INDEX_MASK; + continue; + } + stats.setHash(hash); + stats.setNameLen((int) nameLen); + stats.setSum(temperature); + stats.setCount(1); + stats.setMin((short) temperature); + stats.setMax((short) temperature); + UNSAFE.copyMemory(inputBase + namePos, stats.nameAddress(), nameLen); + return; + } + } + + private int parseTemperatureAndAdvanceCursor(long semicolonPos) { + long startOffset = semicolonPos + 1; + if (startOffset <= inputSize - Long.BYTES) { + return parseTemperatureSwarAndAdvanceCursor(startOffset); + } + return parseTemperatureSimpleAndAdvanceCursor(startOffset); + } + + // Credit: merykitty + private int parseTemperatureSwarAndAdvanceCursor(long startOffset) { + long word = UNSAFE.getLong(inputBase + startOffset); + final long negated = ~word; + final int dotPos = Long.numberOfTrailingZeros(negated & 0x10101000); + final long signed = (negated << 59) >> 63; + final long removeSignMask = ~(signed & 0xFF); + final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L; + final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; + final int temperature = (int) ((absValue ^ signed) - signed); + cursor = startOffset + (dotPos / 8) + 3; + return temperature; + } + + private int parseTemperatureSimpleAndAdvanceCursor(long startOffset) { + final byte minus = (byte) '-'; + final byte zero = (byte) '0'; + final byte dot = (byte) '.'; + + // Temperature plus the following newline is at least 4 chars, so this is always safe: + int fourCh = UNSAFE.getInt(inputBase + startOffset); + final int mask = 0xFF; + byte ch = (byte) (fourCh & mask); + int shift = 0; + int temperature; + int sign; + if (ch == minus) { + sign = -1; + shift += 8; + ch = (byte) ((fourCh & (mask << shift)) >>> shift); + } + else { + sign = 1; + } + temperature = ch - zero; + shift += 8; + ch = (byte) ((fourCh & (mask << shift)) >>> shift); + if (ch == dot) { + shift += 8; + ch = (byte) ((fourCh & (mask << shift)) >>> shift); + } + else { + temperature = 10 * temperature + (ch - zero); + shift += 16; + // The last character may be past the four loaded bytes, load it from memory. + // Checking that with another `if` is self-defeating for performance. + ch = UNSAFE.getByte(inputBase + startOffset + (shift / 8)); + } + temperature = 10 * temperature + (ch - zero); + // `shift` holds the number of bits in the temperature field. + // A newline character follows the temperature, and so we advance + // the cursor past the newline to the start of the next line. + cursor = startOffset + (shift / 8) + 2; + return sign * temperature; + } + + private static long hash(long word1) { + long seed = 0x51_7c_c1_b7_27_22_0a_95L; + int rotDist = 17; + + long hash = word1; + hash *= seed; + hash = Long.rotateLeft(hash, rotDist); + // hash ^= word2; + // hash *= seed; + // hash = Long.rotateLeft(hash, rotDist); + return hash; + } + + private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2) { + boolean mismatch1 = maskWord(inputWord1, len) != UNSAFE.getLong(statsAddr); + boolean mismatch2 = maskWord(inputWord2, len - Long.BYTES) != UNSAFE.getLong(statsAddr + Long.BYTES); + if (mismatch1 | mismatch2) { + return false; + } + for (int i = 2 * Long.BYTES; i < len; i++) { + if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) { + return false; + } + } + return true; + } + + private static long maskWord(long word, long len) { + long halfShiftDistance = Long.max(0, Long.BYTES - len) << 2; + long mask = (~0L >>> halfShiftDistance) >>> halfShiftDistance; // avoid Java trap of shiftDist % 64 + return word & mask; + } + + private static final long BROADCAST_0x01 = broadcastByte(0x01); + private static final long BROADCAST_0x80 = broadcastByte(0x80); + + // Adapted from https://jameshfisher.com/2017/01/24/bitwise-check-for-zero-byte/ + // and https://github.com/ashvardanian/StringZilla/blob/14e7a78edcc16b031c06b375aac1f66d8f19d45a/stringzilla/stringzilla.h#L139-L169 + long posOfSemicolon(long word1, long word2) { + long diff = word1 ^ BROADCAST_SEMICOLON; + long matchBits1 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; + diff = word2 ^ BROADCAST_SEMICOLON; + long matchBits2 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; + if ((matchBits1 | matchBits2) != 0) { + int trailing1 = Long.numberOfTrailingZeros(matchBits1); + int match1IsNonZero = trailing1 & 63; + match1IsNonZero |= match1IsNonZero >>> 3; + match1IsNonZero |= match1IsNonZero >>> 1; + match1IsNonZero |= match1IsNonZero >>> 1; + // Now match1IsNonZero is 1 if it's non-zero, else 0. Use it to + // raise the lowest bit in traling2 if trailing1 is nonzero. This forces + // trailing2 to be zero if trailing1 is non-zero. + int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63; + return cursor + ((trailing1 | trailing2) >> 3); + } + long offset = cursor + 2 * Long.BYTES; + for (; offset <= inputSize - Long.BYTES; offset += Long.BYTES) { + var block = UNSAFE.getLong(inputBase + offset); + diff = block ^ BROADCAST_SEMICOLON; + long matchBits = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; + if (matchBits != 0) { + return offset + Long.numberOfTrailingZeros(matchBits) / 8; + } + } + return posOfSemicolonSimple(offset); + } + + private long posOfSemicolonSimple(long offset) { + for (; offset < inputSize; offset++) { + if (UNSAFE.getByte(inputBase + offset) == SEMICOLON) { + return offset; + } + } + throw new RuntimeException("Semicolon not found"); + } + + // Copies the results from native memory to Java heap and puts them into the results array. + private void exportResults() { + var exportedStats = new ArrayList(10_000); + for (int i = 0; i < STATS_TABLE_SIZE; i++) { + stats.gotoIndex(i); + if (stats.nameLen() == 0) { + continue; + } + var sum = stats.sum(); + var count = stats.count(); + var min = stats.min(); + var max = stats.max(); + var name = stats.exportNameString(); + var stationStats = new StationStats(); + stationStats.name = name; + stationStats.sum = sum; + stationStats.count = count; + stationStats.min = min; + stationStats.max = max; + exportedStats.add(stationStats); + } + StationStats[] exported = exportedStats.toArray(new StationStats[0]); + Arrays.sort(exported); + results[myIndex] = exported; + } + + private final ByteBuffer buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder()); + + private String longToString(long word) { + buf.clear(); + buf.putLong(word); + return new String(buf.array(), StandardCharsets.UTF_8); // + "|" + Arrays.toString(buf.array()); + } + } + + private static long broadcastByte(int b) { + long nnnnnnnn = b; + nnnnnnnn |= nnnnnnnn << 8; + nnnnnnnn |= nnnnnnnn << 16; + nnnnnnnn |= nnnnnnnn << 32; + return nnnnnnnn; + } + + static class StatsAccessor { + static final int NAME_SLOT_SIZE = 104; + static final long HASH_OFFSET = 0; + static final long NAMELEN_OFFSET = HASH_OFFSET + Long.BYTES; + static final long SUM_OFFSET = NAMELEN_OFFSET + Integer.BYTES; + static final long COUNT_OFFSET = SUM_OFFSET + Integer.BYTES; + static final long MIN_OFFSET = COUNT_OFFSET + Integer.BYTES; + static final long MAX_OFFSET = MIN_OFFSET + Short.BYTES; + static final long NAME_OFFSET = MAX_OFFSET + Short.BYTES; + static final long SIZEOF = (NAME_OFFSET + NAME_SLOT_SIZE - 1) / 8 * 8 + 8; + + static final int ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + + private final long address; + private long slotBase; + + StatsAccessor(MemorySegment memSeg) { + memSeg.fill((byte) 0); + this.address = memSeg.address(); + } + + void gotoIndex(int index) { + slotBase = address + index * SIZEOF; + } + + long hash() { + return UNSAFE.getLong(slotBase + HASH_OFFSET); + } + + int nameLen() { + return UNSAFE.getInt(slotBase + NAMELEN_OFFSET); + } + + int sum() { + return UNSAFE.getInt(slotBase + SUM_OFFSET); + } + + int count() { + return UNSAFE.getInt(slotBase + COUNT_OFFSET); + } + + short min() { + return UNSAFE.getShort(slotBase + MIN_OFFSET); + } + + short max() { + return UNSAFE.getShort(slotBase + MAX_OFFSET); + } + + long nameAddress() { + return slotBase + NAME_OFFSET; + } + + String exportNameString() { + final var bytes = new byte[nameLen()]; + UNSAFE.copyMemory(null, nameAddress(), bytes, ARRAY_BASE_OFFSET, nameLen()); + return new String(bytes, StandardCharsets.UTF_8); + } + + void setHash(long hash) { + UNSAFE.putLong(slotBase + HASH_OFFSET, hash); + } + + void setNameLen(int nameLen) { + UNSAFE.putInt(slotBase + NAMELEN_OFFSET, nameLen); + } + + void setSum(int sum) { + UNSAFE.putInt(slotBase + SUM_OFFSET, sum); + } + + void setCount(int count) { + UNSAFE.putInt(slotBase + COUNT_OFFSET, count); + } + + void setMin(short min) { + UNSAFE.putShort(slotBase + MIN_OFFSET, min); + } + + void setMax(short max) { + UNSAFE.putShort(slotBase + MAX_OFFSET, max); + } + } + + private static void mergeSortAndPrint(StationStats[][] results) { + var onFirst = true; + System.out.print('{'); + var cursors = new int[results.length]; + var indexOfMin = 0; + StationStats curr = null; + int exhaustedCount; + while (true) { + exhaustedCount = 0; + StationStats min = null; + for (int i = 0; i < cursors.length; i++) { + if (cursors[i] == results[i].length) { + exhaustedCount++; + continue; + } + StationStats candidate = results[i][cursors[i]]; + if (min == null || min.compareTo(candidate) > 0) { + indexOfMin = i; + min = candidate; + } + } + if (exhaustedCount == cursors.length) { + if (!onFirst) { + System.out.print(", "); + } + System.out.print(curr); + break; + } + cursors[indexOfMin]++; + if (curr == null) { + curr = min; + } + else if (min.equals(curr)) { + curr.sum += min.sum; + curr.count += min.count; + curr.min = Integer.min(curr.min, min.min); + curr.max = Integer.max(curr.max, min.max); + } + else { + if (onFirst) { + onFirst = false; + } + else { + System.out.print(", "); + } + System.out.print(curr); + curr = min; + } + } + System.out.println('}'); + } +}