From dbdd89a84779761ca092e5aaeb6f6e92394a422d Mon Sep 17 00:00:00 2001 From: Jaromir Hamala Date: Mon, 15 Jan 2024 18:55:22 +0100 Subject: [PATCH] jerrinot's initial submission (#424) * initial version let's exploit that superscalar beauty! * give credits where credits is due also: added ideas I don't want to forget --- calculate_average_jerrinot.sh | 21 + prepare_jerrinot.sh | 19 + .../onebrc/CalculateAverage_jerrinot.java | 482 ++++++++++++++++++ 3 files changed, 522 insertions(+) create mode 100755 calculate_average_jerrinot.sh create mode 100755 prepare_jerrinot.sh create mode 100644 src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java diff --git a/calculate_average_jerrinot.sh b/calculate_average_jerrinot.sh new file mode 100755 index 0000000..1bbf680 --- /dev/null +++ b/calculate_average_jerrinot.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_mtopolnik::recordMeasurementAndAdvanceCursor" +# -XX:InlineSmallCode=10000 -XX:-TieredCompilation -XX:CICompilerCount=2 -XX:CompileThreshold=1000\ +java --enable-preview \ + --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jerrinot diff --git a/prepare_jerrinot.sh b/prepare_jerrinot.sh new file mode 100755 index 0000000..f83a3ff --- /dev/null +++ b/prepare_jerrinot.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.1-graal 1>&2 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java new file mode 100644 index 0000000..6fb89bb --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java @@ -0,0 +1,482 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import sun.misc.Unsafe; + +import java.io.File; +import java.io.RandomAccessFile; +import java.lang.foreign.Arena; +import java.lang.reflect.Field; +import java.nio.channels.FileChannel.MapMode; +import java.util.Map; +import java.util.TreeMap; + +public class CalculateAverage_jerrinot { + private static final Unsafe UNSAFE = unsafe(); + private static final String MEASUREMENTS_TXT = "measurements.txt"; + // todo: with hyper-threading enable we would be better of with availableProcessors / 2; + // todo: validate the testing env. params. + private static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors(); + private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL; + + private static Unsafe unsafe() { + try { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + return (Unsafe) theUnsafe.get(Unsafe.class); + } + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + public static void main(String[] args) throws Exception { + calculate(); + } + + static void calculate() throws Exception { + final File file = new File(MEASUREMENTS_TXT); + final long length = file.length(); + // final int chunkCount = Runtime.getRuntime().availableProcessors(); + int chunkPerThread = 4; + final int chunkCount = THREAD_COUNT * chunkPerThread; + final var chunkStartOffsets = new long[chunkCount + 1]; + try (var raf = new RandomAccessFile(file, "r")) { + // credit - chunking code: mtopolnik + final var inputBase = raf.getChannel().map(MapMode.READ_ONLY, 0, length, Arena.global()).address(); + for (int i = 1; i < chunkStartOffsets.length - 1; i++) { + var start = length * i / (chunkStartOffsets.length - 1); + raf.seek(start); + while (raf.read() != (byte) '\n') { + } + start = raf.getFilePointer(); + chunkStartOffsets[i] = start + inputBase; + } + chunkStartOffsets[0] = inputBase; + chunkStartOffsets[chunkCount] = inputBase + length; + + Processor[] processors = new Processor[THREAD_COUNT]; + Thread[] threads = new Thread[THREAD_COUNT]; + + for (int i = 0; i < THREAD_COUNT; i++) { + long startA = chunkStartOffsets[i * chunkPerThread]; + long endA = chunkStartOffsets[i * chunkPerThread + 1]; + long startB = chunkStartOffsets[i * chunkPerThread + 1]; + long endB = chunkStartOffsets[i * chunkPerThread + 2]; + long startC = chunkStartOffsets[i * chunkPerThread + 2]; + long endC = chunkStartOffsets[i * chunkPerThread + 3]; + long startD = chunkStartOffsets[i * chunkPerThread + 3]; + long endD = chunkStartOffsets[i * chunkPerThread + 4]; + + Processor processor = new Processor(startA, endA, startB, endB, startC, endC, startD, endD); + processors[i] = processor; + Thread thread = new Thread(processor); + threads[i] = thread; + thread.start(); + } + + var accumulator = new TreeMap(); + for (int i = 0; i < THREAD_COUNT; i++) { + Thread t = threads[i]; + t.join(); + processors[i].accumulateStatus(accumulator); + } + + var sb = new StringBuilder(); + boolean first = true; + for (Map.Entry statsEntry : accumulator.entrySet()) { + if (first) { + sb.append("{"); + first = false; + } + else { + sb.append(", "); + } + var value = statsEntry.getValue(); + var name = statsEntry.getKey(); + int min = value.min; + int max = value.max; + int count = value.count; + long sum2 = value.sum; + sb.append(String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum2 / count) / 10.0, max / 10.0)); + } + System.out.print(sb); + System.out.println('}'); + } + } + + public static int ceilPow2(int i) { + i--; + i |= i >> 1; + i |= i >> 2; + i |= i >> 4; + i |= i >> 8; + i |= i >> 16; + return i + 1; + } + + private static class Processor implements Runnable { + private static final int MAP_SLOT_COUNT = ceilPow2(10000); + private static final int STATION_MAX_NAME_BYTES = 104; + + private static final long COUNT_OFFSET = 0; + private static final long MIN_OFFSET = 4; + private static final long MAX_OFFSET = 8; + private static final long SUM_OFFSET = 12; + private static final long LEN_OFFSET = 20; + private static final long NAME_OFFSET = 24; + + private static final int MAP_ENTRY_SIZE_BYTES = +Integer.BYTES // count // 0 + + Integer.BYTES // min // +4 + + Integer.BYTES // max // +8 + + Long.BYTES // sum // +12 + + Integer.BYTES // station name len // +20 + + STATION_MAX_NAME_BYTES; // +24 + + private static final int MAP_SIZE_BYTES = MAP_SLOT_COUNT * MAP_ENTRY_SIZE_BYTES; + private static final long MAP_MASK = MAP_SLOT_COUNT - 1; + + // todo: some fields could probably be converted to locals + + private final long map; + + private long cursorA; + private long endA; + private long cursorB; + private long endB; + private long cursorC; + private long endC; + private long cursorD; + private long endD; + private long maskA; + private long maskB; + private long maskC; + private long maskD; + + // credit: merykitty + private long parseAndStoreTemperature(long startCursor, long baseEntryPtr) { + long word = UNSAFE.getLong(startCursor); + final long negateda = ~word; + final int dotPos = Long.numberOfTrailingZeros(negateda & 0x10101000); + final long signed = (negateda << 59) >> 63; + final long removeSignMask = ~(signed & 0xFF); + final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L; + final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; + final int temperature = (int) ((absValue ^ signed) - signed); + + long countPtr = baseEntryPtr + COUNT_OFFSET; + long minPtr = baseEntryPtr + MIN_OFFSET; + long maxPtr = baseEntryPtr + MAX_OFFSET; + long sumPtr = baseEntryPtr + SUM_OFFSET; + + int min = UNSAFE.getInt(minPtr); + int max = UNSAFE.getInt(maxPtr); + long sum = UNSAFE.getLong(sumPtr); + // try if min/max intrinsics are paying off + // maybe braching is better? the branch is becoming more predictable with + // each new sample. + max = Math.max(max, temperature); + min = Math.min(min, temperature); + sum += temperature; + UNSAFE.putInt(countPtr, UNSAFE.getInt(countPtr) + 1); + UNSAFE.putInt(minPtr, min); + UNSAFE.putInt(maxPtr, max); + UNSAFE.putLong(sumPtr, sum); + return startCursor + (dotPos / 8) + 3; + } + + private static long getDelimiterMask(final long word) { + // credit royvanrijn + final long match = word ^ SEPARATOR_PATTERN; + return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L); + } + + // todo: immutability cost us in allocations, but that's probably peanuts in the grand scheme of things. still worth checking + // maybe JVM trusting Final in Records offsets it ..a test is needed + record StationStats(int min, int max, int count, long sum) { + } + + void accumulateStatus(TreeMap accumulator) { + for (long baseAddress = map; baseAddress < map + MAP_SIZE_BYTES; baseAddress += MAP_ENTRY_SIZE_BYTES) { + long len = UNSAFE.getInt(baseAddress + LEN_OFFSET); + if (len == 0) { + continue; + } + byte[] nameArr = new byte[(int) len]; + long baseNameAddr = baseAddress + NAME_OFFSET; + for (int i = 0; i < len; i++) { + nameArr[i] = UNSAFE.getByte(baseNameAddr + i); + } + String name = new String(nameArr); + int min = UNSAFE.getInt(baseAddress + MIN_OFFSET); + int max = UNSAFE.getInt(baseAddress + MAX_OFFSET); + int count = UNSAFE.getInt(baseAddress + COUNT_OFFSET); + long sum = UNSAFE.getLong(baseAddress + SUM_OFFSET); + + // todo: lambdas bootstrap probably cost us + accumulator.compute(name, (_, v) -> { + if (v == null) { + return new StationStats(min, max, count, sum); + } + return new StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum); + }); + } + } + + Processor(long startA, long endA, long startB, long endB, long startC, long endC, long startD, long endD) { + this.cursorA = startA; + this.cursorB = startB; + this.cursorC = startC; + this.cursorD = startD; + this.endA = endA; + this.endB = endB; + this.endC = endC; + this.endD = endD; + this.map = UNSAFE.allocateMemory(MAP_SIZE_BYTES); + + int i; + for (i = 0; i < MAP_SIZE_BYTES; i += 8) { + UNSAFE.putLong(map + i, 0); + } + for (i = i - 8; i < MAP_SIZE_BYTES; i++) { + UNSAFE.putByte(map + i, (byte) 0); + } + } + + private void doTail() { + // todo: we would be probably better of without all that code dup. ("compilers hates him!") + // System.out.println("done ILP"); + while (cursorA < endA) { + long startA = cursorA; + long delimiterWordA = UNSAFE.getLong(cursorA); + long hashA = 0; + maskA = getDelimiterMask(delimiterWordA); + while (maskA == 0) { + hashA ^= delimiterWordA; + cursorA += 8; + delimiterWordA = UNSAFE.getLong(cursorA); + maskA = getDelimiterMask(delimiterWordA); + } + final int delimiterByteA = Long.numberOfTrailingZeros(maskA); + final long semicolonA = cursorA + (delimiterByteA >> 3); + final long maskedWordA = delimiterWordA & ((maskA >>> 7) - 1); + hashA ^= maskedWordA; + int intHashA = (int) (hashA ^ (hashA >> 32)); + intHashA = intHashA ^ (intHashA >> 17); + + long baseEntryPtrA = getOrCreateEntryBaseOffset(semicolonA, startA, intHashA, maskedWordA); + cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA); + } + // System.out.println("done A"); + while (cursorB < endB) { + long startB = cursorB; + long delimiterWordB = UNSAFE.getLong(cursorB); + long hashB = 0; + maskB = getDelimiterMask(delimiterWordB); + while (maskB == 0) { + hashB ^= delimiterWordB; + cursorB += 8; + delimiterWordB = UNSAFE.getLong(cursorB); + maskB = getDelimiterMask(delimiterWordB); + } + final int delimiterByteB = Long.numberOfTrailingZeros(maskB); + final long semicolonB = cursorB + (delimiterByteB >> 3); + final long maskedWordB = delimiterWordB & ((maskB >>> 7) - 1); + hashB ^= maskedWordB; + int intHashB = (int) (hashB ^ (hashB >> 32)); + intHashB = intHashB ^ (intHashB >> 17); + + long baseEntryPtrB = getOrCreateEntryBaseOffset(semicolonB, startB, intHashB, maskedWordB); + cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB); + } + // System.out.println("done B"); + while (cursorC < endC) { + long startC = cursorC; + long delimiterWordC = UNSAFE.getLong(cursorC); + long hashC = 0; + maskC = getDelimiterMask(delimiterWordC); + while (maskC == 0) { + hashC ^= delimiterWordC; + cursorC += 8; + delimiterWordC = UNSAFE.getLong(cursorC); + maskC = getDelimiterMask(delimiterWordC); + } + final int delimiterByteC = Long.numberOfTrailingZeros(maskC); + final long semicolonC = cursorC + (delimiterByteC >> 3); + final long maskedWordC = delimiterWordC & ((maskC >>> 7) - 1); + hashC ^= maskedWordC; + int intHashC = (int) (hashC ^ (hashC >> 32)); + intHashC = intHashC ^ (intHashC >> 17); + + long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, intHashC, maskedWordC); + cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC); + } + // System.out.println("done C"); + while (cursorD < endD) { + long startD = cursorD; + long delimiterWordD = UNSAFE.getLong(cursorD); + long hashD = 0; + maskD = getDelimiterMask(delimiterWordD); + while (maskD == 0) { + hashD ^= delimiterWordD; + cursorD += 8; + delimiterWordD = UNSAFE.getLong(cursorD); + maskD = getDelimiterMask(delimiterWordD); + } + final int delimiterByteD = Long.numberOfTrailingZeros(maskD); + final long semicolonD = cursorD + (delimiterByteD >> 3); + final long maskedWordD = delimiterWordD & ((maskD >>> 7) - 1); + hashD ^= maskedWordD; + int intHashD = (int) (hashD ^ (hashD >> 32)); + intHashD = intHashD ^ (intHashD >> 17); + + long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, intHashD, maskedWordD); + cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD); + } + // System.out.println("done D"); + } + + @Override + public void run() { + while (cursorA < endA && cursorB < endB && cursorC < endC && cursorD < endD) { + // todo: experiment with different inter-leaving + long startA = cursorA; + long startB = cursorB; + long startC = cursorC; + long startD = cursorD; + + long delimiterWordA = UNSAFE.getLong(cursorA); + long delimiterWordB = UNSAFE.getLong(cursorB); + long delimiterWordC = UNSAFE.getLong(cursorC); + long delimiterWordD = UNSAFE.getLong(cursorD); + + long hashA = 0; + long hashB = 0; + long hashC = 0; + long hashD = 0; + + // credits for the hashing idea: royvanrijn + maskA = getDelimiterMask(delimiterWordA); + while (maskA == 0) { + hashA ^= delimiterWordA; + cursorA += 8; + delimiterWordA = UNSAFE.getLong(cursorA); + maskA = getDelimiterMask(delimiterWordA); + } + final int delimiterByteA = Long.numberOfTrailingZeros(maskA); + final long semicolonA = cursorA + (delimiterByteA >> 3); + final long maskedWordA = delimiterWordA & ((maskA >>> 7) - 1); + hashA ^= maskedWordA; + int intHashA = (int) (hashA ^ (hashA >> 32)); + intHashA = intHashA ^ (intHashA >> 17); + + maskB = getDelimiterMask(delimiterWordB); + while (maskB == 0) { + hashB ^= delimiterWordB; + cursorB += 8; + delimiterWordB = UNSAFE.getLong(cursorB); + maskB = getDelimiterMask(delimiterWordB); + } + final int delimiterByteB = Long.numberOfTrailingZeros(maskB); + final long semicolonB = cursorB + (delimiterByteB >> 3); + final long maskedWordB = delimiterWordB & ((maskB >>> 7) - 1); + hashB ^= maskedWordB; + int intHashB = (int) (hashB ^ (hashB >> 32)); + intHashB = intHashB ^ (intHashB >> 17); + + maskC = getDelimiterMask(delimiterWordC); + while (maskC == 0) { + hashC ^= delimiterWordC; + cursorC += 8; + delimiterWordC = UNSAFE.getLong(cursorC); + maskC = getDelimiterMask(delimiterWordC); + } + final int delimiterByteC = Long.numberOfTrailingZeros(maskC); + final long semicolonC = cursorC + (delimiterByteC >> 3); + final long maskedWordC = delimiterWordC & ((maskC >>> 7) - 1); + hashC ^= maskedWordC; + int intHashC = (int) (hashC ^ (hashC >> 32)); + intHashC = intHashC ^ (intHashC >> 17); + + maskD = getDelimiterMask(delimiterWordD); + while (maskD == 0) { + hashD ^= delimiterWordD; + cursorD += 8; + delimiterWordD = UNSAFE.getLong(cursorD); + maskD = getDelimiterMask(delimiterWordD); + } + final int delimiterByteD = Long.numberOfTrailingZeros(maskD); + final long semicolonD = cursorD + (delimiterByteD >> 3); + final long maskedWordD = delimiterWordD & ((maskD >>> 7) - 1); + hashD ^= maskedWordD; + int intHashD = (int) (hashD ^ (hashD >> 32)); + intHashD = intHashD ^ (intHashD >> 17); + + long baseEntryPtrA = getOrCreateEntryBaseOffset(semicolonA, startA, intHashA, maskedWordA); + long baseEntryPtrB = getOrCreateEntryBaseOffset(semicolonB, startB, intHashB, maskedWordB); + long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, intHashC, maskedWordC); + long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, intHashD, maskedWordD); + + cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA); + cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB); + cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC); + cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD); + } + doTail(); + } + + private long getOrCreateEntryBaseOffset(long semicolonA, long startA, int intHashA, long maskedWordA) { + int lenA = (int) (semicolonA - startA); + long mapIndexA = intHashA & MAP_MASK; + for (;;) { + long basePtr = mapIndexA * MAP_ENTRY_SIZE_BYTES + map; + long lenPtr = basePtr + LEN_OFFSET; + int len = UNSAFE.getInt(lenPtr); + if (len == 0) { + // todo: uncommon branch maybe? + // empty slot + UNSAFE.copyMemory(semicolonA - lenA, basePtr + NAME_OFFSET, lenA); + UNSAFE.putInt(lenPtr, lenA); + UNSAFE.putInt(basePtr + MAX_OFFSET, Integer.MIN_VALUE); + UNSAFE.putInt(basePtr + MIN_OFFSET, Integer.MAX_VALUE); + return basePtr; + } + if (len == lenA) { + boolean match = true; + long namePtr = basePtr + NAME_OFFSET; + int fullLen = (len >> 3) << 3; + long offset; + // todo: this is worth exploring further. + // @mtopolnik has an interesting algo with 2 unconditioned long loads: this is sufficient + // for majority of names. so we would be left with just a single branch which is almost never taken? + for (offset = 0; offset < fullLen; offset += 8) { + match &= (UNSAFE.getLong(startA + offset) == UNSAFE.getLong(namePtr + offset)); + } + + long maskedWordInMap = UNSAFE.getLong(namePtr + offset); + match &= (maskedWordInMap == maskedWordA); + + if (match) { + return basePtr; + } + } + mapIndexA = ++mapIndexA & MAP_MASK; + } + } + } + +}