/* * Copyright 2023 The original authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package dev.morling.onebrc; import sun.misc.Unsafe; import java.io.File; import java.io.RandomAccessFile; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.reflect.Field; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.channels.FileChannel.MapMode; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; public class CalculateAverage_mtopolnik { private static final Unsafe UNSAFE = unsafe(); private static final int MAX_NAME_LEN = 100; private static final int STATS_TABLE_SIZE = 1 << 16; private static final int TABLE_INDEX_MASK = STATS_TABLE_SIZE - 1; private static final String MEASUREMENTS_TXT = "measurements.txt"; private static final byte SEMICOLON = ';'; private static final long BROADCAST_SEMICOLON = broadcastByte(SEMICOLON); // These two are just informative, I let the IDE calculate them for me private static final long NATIVE_MEM_PER_THREAD = StatsAccessor.SIZEOF * STATS_TABLE_SIZE; private static final long NATIVE_MEM_ON_8_THREADS = 8 * NATIVE_MEM_PER_THREAD; private static Unsafe unsafe() { try { Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); theUnsafe.setAccessible(true); return (Unsafe) theUnsafe.get(Unsafe.class); } catch (NoSuchFieldException | IllegalAccessException e) { throw new RuntimeException(e); } } static class StationStats implements Comparable { String name; long sum; int count; int min; int max; @Override public String toString() { return String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum / count) / 10.0, max / 10.0); } @Override public boolean equals(Object that) { return that.getClass() == StationStats.class && ((StationStats) that).name.equals(this.name); } @Override public int compareTo(StationStats that) { return name.compareTo(that.name); } } public static void main(String[] args) throws Exception { calculate(); } static void calculate() throws Exception { final File file = new File(MEASUREMENTS_TXT); final long length = file.length(); final int chunkCount = Runtime.getRuntime().availableProcessors(); final var results = new StationStats[chunkCount][]; final var chunkStartOffsets = new long[chunkCount]; try (var raf = new RandomAccessFile(file, "r")) { final var inputBase = raf.getChannel().map(MapMode.READ_ONLY, 0, length, Arena.global()).address(); for (int i = 1; i < chunkStartOffsets.length; i++) { var start = length * i / chunkStartOffsets.length; raf.seek(start); while (raf.read() != (byte) '\n') { } start = raf.getFilePointer(); chunkStartOffsets[i] = start; } var threads = new Thread[chunkCount]; for (int i = 0; i < chunkCount; i++) { final long chunkStart = chunkStartOffsets[i]; final long chunkLimit = (i + 1 < chunkCount) ? chunkStartOffsets[i + 1] : length; threads[i] = new Thread(new ChunkProcessor(inputBase + chunkStart, inputBase + chunkLimit, results, i)); } for (var thread : threads) { thread.start(); } for (var thread : threads) { thread.join(); } } mergeSortAndPrint(results); } private static class ChunkProcessor implements Runnable { private static final long NAMEBUF_SIZE = 2 * Long.BYTES; private static final int CACHELINE_SIZE = 64; private final long inputBase; private final long inputSize; private final StationStats[][] results; private final int myIndex; private StatsAccessor stats; private long nameBufBase; private long cursor; ChunkProcessor(long chunkStart, long chunkLimit, StationStats[][] results, int myIndex) { this.inputBase = chunkStart; this.inputSize = chunkLimit - chunkStart; this.results = results; this.myIndex = myIndex; } @Override public void run() { try (Arena confinedArena = Arena.ofConfined()) { long totalAllocated = 0; String threadName = Thread.currentThread().getName(); long statsByteSize = STATS_TABLE_SIZE * StatsAccessor.SIZEOF; var diagnosticString = String.format("Thread %s needs %,d bytes, managed to allocate before OOM: ", threadName, statsByteSize + NAMEBUF_SIZE); try { stats = new StatsAccessor(confinedArena.allocate(statsByteSize, CACHELINE_SIZE)); totalAllocated = statsByteSize; nameBufBase = confinedArena.allocate(NAMEBUF_SIZE).address(); } catch (OutOfMemoryError e) { System.err.print(diagnosticString); System.err.println(totalAllocated); throw e; } processChunk(); exportResults(); } } private void processChunk() { while (cursor < inputSize) { long word1; long word2; if (cursor + 2 * Long.BYTES <= inputSize) { word1 = UNSAFE.getLong(inputBase + cursor); word2 = UNSAFE.getLong(inputBase + cursor + Long.BYTES); } else { UNSAFE.putLong(nameBufBase, 0); UNSAFE.putLong(nameBufBase + Long.BYTES, 0); UNSAFE.copyMemory(inputBase + cursor, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor)); word1 = UNSAFE.getLong(nameBufBase); word2 = UNSAFE.getLong(nameBufBase + Long.BYTES); } long posOfSemicolon = posOfSemicolon(word1, word2); word1 = maskWord(word1, posOfSemicolon - cursor); word2 = maskWord(word2, posOfSemicolon - cursor - Long.BYTES); long hash = hash(word1); long namePos = cursor; long nameLen = posOfSemicolon - cursor; assert nameLen <= 100 : "nameLen > 100"; int temperature = parseTemperatureAndAdvanceCursor(posOfSemicolon); updateStats(hash, namePos, nameLen, word1, word2, temperature); } } private void updateStats(long hash, long namePos, long nameLen, long nameWord1, long nameWord2, int temperature) { int tableIndex = (int) (hash & TABLE_INDEX_MASK); while (true) { stats.gotoIndex(tableIndex); if (stats.hash() == hash && stats.nameLen() == nameLen && nameEquals(stats.nameAddress(), inputBase + namePos, nameLen, nameWord1, nameWord2)) { stats.setSum(stats.sum() + temperature); stats.setCount(stats.count() + 1); stats.setMin((short) Integer.min(stats.min(), temperature)); stats.setMax((short) Integer.max(stats.max(), temperature)); return; } if (stats.nameLen() != 0) { tableIndex = (tableIndex + 1) & TABLE_INDEX_MASK; continue; } stats.setHash(hash); stats.setNameLen((int) nameLen); stats.setSum(temperature); stats.setCount(1); stats.setMin((short) temperature); stats.setMax((short) temperature); UNSAFE.copyMemory(inputBase + namePos, stats.nameAddress(), nameLen); return; } } private int parseTemperatureAndAdvanceCursor(long semicolonPos) { long startOffset = semicolonPos + 1; if (startOffset <= inputSize - Long.BYTES) { return parseTemperatureSwarAndAdvanceCursor(startOffset); } return parseTemperatureSimpleAndAdvanceCursor(startOffset); } // Credit: merykitty private int parseTemperatureSwarAndAdvanceCursor(long startOffset) { long word = UNSAFE.getLong(inputBase + startOffset); final long negated = ~word; final int dotPos = Long.numberOfTrailingZeros(negated & 0x10101000); final long signed = (negated << 59) >> 63; final long removeSignMask = ~(signed & 0xFF); final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L; final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; final int temperature = (int) ((absValue ^ signed) - signed); cursor = startOffset + (dotPos / 8) + 3; return temperature; } private int parseTemperatureSimpleAndAdvanceCursor(long startOffset) { final byte minus = (byte) '-'; final byte zero = (byte) '0'; final byte dot = (byte) '.'; // Temperature plus the following newline is at least 4 chars, so this is always safe: int fourCh = UNSAFE.getInt(inputBase + startOffset); final int mask = 0xFF; byte ch = (byte) (fourCh & mask); int shift = 0; int temperature; int sign; if (ch == minus) { sign = -1; shift += 8; ch = (byte) ((fourCh & (mask << shift)) >>> shift); } else { sign = 1; } temperature = ch - zero; shift += 8; ch = (byte) ((fourCh & (mask << shift)) >>> shift); if (ch == dot) { shift += 8; ch = (byte) ((fourCh & (mask << shift)) >>> shift); } else { temperature = 10 * temperature + (ch - zero); shift += 16; // The last character may be past the four loaded bytes, load it from memory. // Checking that with another `if` is self-defeating for performance. ch = UNSAFE.getByte(inputBase + startOffset + (shift / 8)); } temperature = 10 * temperature + (ch - zero); // `shift` holds the number of bits in the temperature field. // A newline character follows the temperature, and so we advance // the cursor past the newline to the start of the next line. cursor = startOffset + (shift / 8) + 2; return sign * temperature; } private static long hash(long word1) { long seed = 0x51_7c_c1_b7_27_22_0a_95L; int rotDist = 17; long hash = word1; hash *= seed; hash = Long.rotateLeft(hash, rotDist); // hash ^= word2; // hash *= seed; // hash = Long.rotateLeft(hash, rotDist); return hash; } private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2) { boolean mismatch1 = maskWord(inputWord1, len) != UNSAFE.getLong(statsAddr); boolean mismatch2 = maskWord(inputWord2, len - Long.BYTES) != UNSAFE.getLong(statsAddr + Long.BYTES); if (mismatch1 | mismatch2) { return false; } for (int i = 2 * Long.BYTES; i < len; i++) { if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) { return false; } } return true; } private static long maskWord(long word, long len) { long halfShiftDistance = Long.max(0, Long.BYTES - len) << 2; long mask = (~0L >>> halfShiftDistance) >>> halfShiftDistance; // avoid Java trap of shiftDist % 64 return word & mask; } private static final long BROADCAST_0x01 = broadcastByte(0x01); private static final long BROADCAST_0x80 = broadcastByte(0x80); // Adapted from https://jameshfisher.com/2017/01/24/bitwise-check-for-zero-byte/ // and https://github.com/ashvardanian/StringZilla/blob/14e7a78edcc16b031c06b375aac1f66d8f19d45a/stringzilla/stringzilla.h#L139-L169 long posOfSemicolon(long word1, long word2) { long diff = word1 ^ BROADCAST_SEMICOLON; long matchBits1 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; diff = word2 ^ BROADCAST_SEMICOLON; long matchBits2 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; if ((matchBits1 | matchBits2) != 0) { int trailing1 = Long.numberOfTrailingZeros(matchBits1); int match1IsNonZero = trailing1 & 63; match1IsNonZero |= match1IsNonZero >>> 3; match1IsNonZero |= match1IsNonZero >>> 1; match1IsNonZero |= match1IsNonZero >>> 1; // Now match1IsNonZero is 1 if it's non-zero, else 0. Use it to // raise the lowest bit in traling2 if trailing1 is nonzero. This forces // trailing2 to be zero if trailing1 is non-zero. int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63; return cursor + ((trailing1 | trailing2) >> 3); } long offset = cursor + 2 * Long.BYTES; for (; offset <= inputSize - Long.BYTES; offset += Long.BYTES) { var block = UNSAFE.getLong(inputBase + offset); diff = block ^ BROADCAST_SEMICOLON; long matchBits = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; if (matchBits != 0) { return offset + Long.numberOfTrailingZeros(matchBits) / 8; } } return posOfSemicolonSimple(offset); } private long posOfSemicolonSimple(long offset) { for (; offset < inputSize; offset++) { if (UNSAFE.getByte(inputBase + offset) == SEMICOLON) { return offset; } } throw new RuntimeException("Semicolon not found"); } // Copies the results from native memory to Java heap and puts them into the results array. private void exportResults() { var exportedStats = new ArrayList(10_000); for (int i = 0; i < STATS_TABLE_SIZE; i++) { stats.gotoIndex(i); if (stats.nameLen() == 0) { continue; } var sum = stats.sum(); var count = stats.count(); var min = stats.min(); var max = stats.max(); var name = stats.exportNameString(); var stationStats = new StationStats(); stationStats.name = name; stationStats.sum = sum; stationStats.count = count; stationStats.min = min; stationStats.max = max; exportedStats.add(stationStats); } StationStats[] exported = exportedStats.toArray(new StationStats[0]); Arrays.sort(exported); results[myIndex] = exported; } private final ByteBuffer buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder()); private String longToString(long word) { buf.clear(); buf.putLong(word); return new String(buf.array(), StandardCharsets.UTF_8); // + "|" + Arrays.toString(buf.array()); } } private static long broadcastByte(int b) { long nnnnnnnn = b; nnnnnnnn |= nnnnnnnn << 8; nnnnnnnn |= nnnnnnnn << 16; nnnnnnnn |= nnnnnnnn << 32; return nnnnnnnn; } static class StatsAccessor { static final int NAME_SLOT_SIZE = 104; static final long HASH_OFFSET = 0; static final long NAMELEN_OFFSET = HASH_OFFSET + Long.BYTES; static final long SUM_OFFSET = NAMELEN_OFFSET + Integer.BYTES; static final long COUNT_OFFSET = SUM_OFFSET + Integer.BYTES; static final long MIN_OFFSET = COUNT_OFFSET + Integer.BYTES; static final long MAX_OFFSET = MIN_OFFSET + Short.BYTES; static final long NAME_OFFSET = MAX_OFFSET + Short.BYTES; static final long SIZEOF = (NAME_OFFSET + NAME_SLOT_SIZE - 1) / 8 * 8 + 8; static final int ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); private final long address; private long slotBase; StatsAccessor(MemorySegment memSeg) { memSeg.fill((byte) 0); this.address = memSeg.address(); } void gotoIndex(int index) { slotBase = address + index * SIZEOF; } long hash() { return UNSAFE.getLong(slotBase + HASH_OFFSET); } int nameLen() { return UNSAFE.getInt(slotBase + NAMELEN_OFFSET); } int sum() { return UNSAFE.getInt(slotBase + SUM_OFFSET); } int count() { return UNSAFE.getInt(slotBase + COUNT_OFFSET); } short min() { return UNSAFE.getShort(slotBase + MIN_OFFSET); } short max() { return UNSAFE.getShort(slotBase + MAX_OFFSET); } long nameAddress() { return slotBase + NAME_OFFSET; } String exportNameString() { final var bytes = new byte[nameLen()]; UNSAFE.copyMemory(null, nameAddress(), bytes, ARRAY_BASE_OFFSET, nameLen()); return new String(bytes, StandardCharsets.UTF_8); } void setHash(long hash) { UNSAFE.putLong(slotBase + HASH_OFFSET, hash); } void setNameLen(int nameLen) { UNSAFE.putInt(slotBase + NAMELEN_OFFSET, nameLen); } void setSum(int sum) { UNSAFE.putInt(slotBase + SUM_OFFSET, sum); } void setCount(int count) { UNSAFE.putInt(slotBase + COUNT_OFFSET, count); } void setMin(short min) { UNSAFE.putShort(slotBase + MIN_OFFSET, min); } void setMax(short max) { UNSAFE.putShort(slotBase + MAX_OFFSET, max); } } private static void mergeSortAndPrint(StationStats[][] results) { var onFirst = true; System.out.print('{'); var cursors = new int[results.length]; var indexOfMin = 0; StationStats curr = null; int exhaustedCount; while (true) { exhaustedCount = 0; StationStats min = null; for (int i = 0; i < cursors.length; i++) { if (cursors[i] == results[i].length) { exhaustedCount++; continue; } StationStats candidate = results[i][cursors[i]]; if (min == null || min.compareTo(candidate) > 0) { indexOfMin = i; min = candidate; } } if (exhaustedCount == cursors.length) { if (!onFirst) { System.out.print(", "); } System.out.print(curr); break; } cursors[indexOfMin]++; if (curr == null) { curr = min; } else if (min.equals(curr)) { curr.sum += min.sum; curr.count += min.count; curr.min = Integer.min(curr.min, min.min); curr.max = Integer.max(curr.max, min.max); } else { if (onFirst) { onFirst = false; } else { System.out.print(", "); } System.out.print(curr); curr = min; } } System.out.println('}'); } }