From 765583e7d89c7cc879d8e67158a228a78d4c2b71 Mon Sep 17 00:00:00 2001 From: zerninv Date: Wed, 17 Jan 2024 17:35:22 +0000 Subject: [PATCH] improve equality check performance, use graal jvm (#454) --- prepare_zerninv.sh | 20 +++ .../onebrc/CalculateAverage_zerninv.java | 129 +++++++++--------- 2 files changed, 86 insertions(+), 63 deletions(-) create mode 100755 prepare_zerninv.sh diff --git a/prepare_zerninv.sh b/prepare_zerninv.sh new file mode 100755 index 0000000..cd3641e --- /dev/null +++ b/prepare_zerninv.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.1-graal 1>&2 \ No newline at end of file diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java index 2e7ea4c..42cf6b8 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java @@ -31,39 +31,36 @@ import java.util.TreeMap; public class CalculateAverage_zerninv { private static final String FILE = "./measurements.txt"; - private static final int L3_CACHE_SIZE = 128 * 1024 * 1024; private static final int CORES = Runtime.getRuntime().availableProcessors(); - private static final int CHUNK_SIZE = (L3_CACHE_SIZE - MeasurementContainer.SIZE * MeasurementContainer.ENTRY_SIZE * CORES) / CORES - 1024 * CORES; - - // #.## - private static final int THREE_DIGITS_MASK = 0x2e0000; - // #.# - private static final int TWO_DIGITS_MASK = 0x2e00; - // #.#- - private static final int TWO_NEGATIVE_DIGITS_MASK = 0x2e002d; - private static final int BYTE_MASK = 0xff; - private static final int ZERO = '0'; - - private static final byte DELIMITER = ';'; - private static final byte LINE_SEPARATOR = '\n'; + private static final int CHUNK_SIZE = 1024 * 1024 * 32; private static final Unsafe UNSAFE = initUnsafe(); + private static Unsafe initUnsafe() { + try { + Field unsafe = Unsafe.class.getDeclaredField("theUnsafe"); + unsafe.setAccessible(true); + return (Unsafe) unsafe.get(Unsafe.class); + } + catch (IllegalAccessException | NoSuchFieldException e) { + throw new RuntimeException(e); + } + } + public static void main(String[] args) throws IOException, InterruptedException { try (var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { var fileSize = channel.size(); var minChunkSize = Math.min(fileSize, CHUNK_SIZE); + var segment = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()); var tasks = new TaskThread[CORES]; for (int i = 0; i < tasks.length; i++) { tasks[i] = new TaskThread(new MeasurementContainer(), (int) (fileSize / minChunkSize / CORES + 1)); } - var memorySegment = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()); - var address = memorySegment.address(); - var chunks = splitByChunks(address, address + fileSize, minChunkSize); + var chunks = splitByChunks(segment.address(), segment.address() + fileSize, minChunkSize); for (int i = 0; i < chunks.size() - 1; i++) { - var task = tasks[i % CORES]; + var task = tasks[i % tasks.length]; task.addChunk(chunks.get(i), chunks.get(i + 1)); } @@ -93,23 +90,12 @@ public class CalculateAverage_zerninv { } } - private static Unsafe initUnsafe() { - try { - Field unsafe = Unsafe.class.getDeclaredField("theUnsafe"); - unsafe.setAccessible(true); - return (Unsafe) unsafe.get(Unsafe.class); - } - catch (IllegalAccessException | NoSuchFieldException e) { - throw new RuntimeException(e); - } - } - private static List splitByChunks(long address, long end, long minChunkSize) { List result = new ArrayList<>((int) ((end - address) / minChunkSize + 1)); result.add(address); while (address < end) { address += Math.min(end - address, minChunkSize); - while (address < end && UNSAFE.getByte(address++) != LINE_SEPARATOR) { + while (address < end && UNSAFE.getByte(address++) != '\n') { } result.add(address); } @@ -141,7 +127,7 @@ public class CalculateAverage_zerninv { @Override public String toString() { - return String.format("%.1f/%.1f/%.1f", min / 10d, sum / 10d / count, max / 10d); + return min / 10d + "/" + Math.round(sum / 1d / count) / 10d + "/" + max / 10d; } } @@ -149,55 +135,59 @@ public class CalculateAverage_zerninv { } private static final class MeasurementContainer { - private static final int SIZE = 1024 * 16; + private static final int SIZE = 1 << 17; - private static final int ENTRY_SIZE = 4 + 4 + 1 + 8 + 8 + 2 + 2; + private static final int ENTRY_SIZE = 4 + 4 + 8 + 1 + 8 + 8 + 2 + 2; private static final int COUNT_OFFSET = 0; private static final int HASH_OFFSET = 4; - private static final int SIZE_OFFSET = 8; - private static final int ADDRESS_OFFSET = 9; - private static final int SUM_OFFSET = 17; - private static final int MIN_OFFSET = 25; - private static final int MAX_OFFSET = 27; + private static final int LAST_BYTES_OFFSET = 8; + private static final int SIZE_OFFSET = 16; + private static final int ADDRESS_OFFSET = 17; + private static final int SUM_OFFSET = 25; + private static final int MIN_OFFSET = 33; + private static final int MAX_OFFSET = 35; private final long address; private MeasurementContainer() { address = UNSAFE.allocateMemory(ENTRY_SIZE * SIZE); UNSAFE.setMemory(address, ENTRY_SIZE * SIZE, (byte) 0); - for (long ptr = address; ptr < address + SIZE * ENTRY_SIZE; ptr += ENTRY_SIZE) { - UNSAFE.putShort(ptr + MIN_OFFSET, Short.MAX_VALUE); - UNSAFE.putShort(ptr + MAX_OFFSET, Short.MIN_VALUE); - } } - public void put(long address, byte size, int hash, short value) { + public void put(long address, byte size, int hash, long lastBytes, short value) { int idx = Math.abs(hash % SIZE); long ptr = this.address + idx * ENTRY_SIZE; int count; + boolean fastEqual; while ((count = UNSAFE.getInt(ptr + COUNT_OFFSET)) != 0) { - if (UNSAFE.getInt(ptr + HASH_OFFSET) == hash - && UNSAFE.getByte(ptr + SIZE_OFFSET) == size - && isEqual(UNSAFE.getLong(ptr + ADDRESS_OFFSET), address, size)) { - break; + fastEqual = UNSAFE.getInt(ptr + HASH_OFFSET) == hash && UNSAFE.getLong(ptr + LAST_BYTES_OFFSET) == lastBytes; + if (fastEqual && UNSAFE.getByte(ptr + SIZE_OFFSET) == size && isEqual(UNSAFE.getLong(ptr + ADDRESS_OFFSET), address, size - 8)) { + + UNSAFE.putInt(ptr + COUNT_OFFSET, count + 1); + UNSAFE.putLong(ptr + ADDRESS_OFFSET, address); + UNSAFE.putLong(ptr + SUM_OFFSET, UNSAFE.getLong(ptr + SUM_OFFSET) + value); + if (value < UNSAFE.getShort(ptr + MIN_OFFSET)) { + UNSAFE.putShort(ptr + MIN_OFFSET, value); + } + if (value > UNSAFE.getShort(ptr + MAX_OFFSET)) { + UNSAFE.putShort(ptr + MAX_OFFSET, value); + } + return; } idx = (idx + 1) % SIZE; ptr = this.address + idx * ENTRY_SIZE; } - UNSAFE.putInt(ptr + COUNT_OFFSET, count + 1); + UNSAFE.putInt(ptr + COUNT_OFFSET, 1); UNSAFE.putInt(ptr + HASH_OFFSET, hash); + UNSAFE.putLong(ptr + LAST_BYTES_OFFSET, lastBytes); UNSAFE.putByte(ptr + SIZE_OFFSET, size); UNSAFE.putLong(ptr + ADDRESS_OFFSET, address); - UNSAFE.putLong(ptr + SUM_OFFSET, UNSAFE.getLong(ptr + SUM_OFFSET) + value); - if (value < UNSAFE.getShort(ptr + MIN_OFFSET)) { - UNSAFE.putShort(ptr + MIN_OFFSET, value); - } - if (value > UNSAFE.getShort(ptr + MAX_OFFSET)) { - UNSAFE.putShort(ptr + MAX_OFFSET, value); - } + UNSAFE.putLong(ptr + SUM_OFFSET, value); + UNSAFE.putShort(ptr + MIN_OFFSET, value); + UNSAFE.putShort(ptr + MAX_OFFSET, value); } public List measurements() { @@ -207,21 +197,21 @@ public class CalculateAverage_zerninv { long ptr = this.address + i * ENTRY_SIZE; count = UNSAFE.getInt(ptr + COUNT_OFFSET); if (count != 0) { + var station = createString(UNSAFE.getLong(ptr + ADDRESS_OFFSET), UNSAFE.getByte(ptr + SIZE_OFFSET)); var measurements = new TemperatureAggregation( UNSAFE.getLong(ptr + SUM_OFFSET), count, UNSAFE.getShort(ptr + MIN_OFFSET), UNSAFE.getShort(ptr + MAX_OFFSET)); - var key = createString(UNSAFE.getLong(ptr + ADDRESS_OFFSET), UNSAFE.getByte(ptr + SIZE_OFFSET)); - result.add(new Measurement(key, measurements)); + result.add(new Measurement(station, measurements)); } } return result; } - private static boolean isEqual(long address, long address2, byte size) { - for (int i = 0; i < size; i++) { - if (UNSAFE.getByte(address + i) != UNSAFE.getByte(address2 + i)) { + private boolean isEqual(long address, long address2, int size) { + for (int i = 0; i < size; i += 8) { + if (UNSAFE.getLong(address + i) != UNSAFE.getLong(address2 + i)) { return false; } } @@ -238,6 +228,17 @@ public class CalculateAverage_zerninv { } private static class TaskThread extends Thread { + // #.## + private static final int THREE_DIGITS_MASK = 0x2e0000; + // #.# + private static final int TWO_DIGITS_MASK = 0x2e00; + // #.#- + private static final int TWO_NEGATIVE_DIGITS_MASK = 0x2e002d; + private static final int BYTE_MASK = 0xff; + + private static final int ZERO = '0'; + private static final byte DELIMITER = ';'; + private final MeasurementContainer container; private final List begins; private final List ends; @@ -265,15 +266,17 @@ public class CalculateAverage_zerninv { } private void calcForChunk(long offset, long end) { - long cityOffset; + long cityOffset, lastBytes; int hashCode, temperature, word; byte cityNameSize, b; while (offset < end) { cityOffset = offset; + lastBytes = 0; hashCode = 0; while ((b = UNSAFE.getByte(offset++)) != DELIMITER) { - hashCode = hashCode * 31 + b; + hashCode += hashCode * 31 + b; + lastBytes = (lastBytes << 8) | b; } cityNameSize = (byte) (offset - cityOffset - 1); @@ -297,7 +300,7 @@ public class CalculateAverage_zerninv { temperature = ZERO * 111 - ((word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK)); } offset++; - container.put(cityOffset, cityNameSize, hashCode, (short) temperature); + container.put(cityOffset, cityNameSize, hashCode, lastBytes, (short) temperature); } } }