From a304f80710940fca3b283e9816b92c46986871d0 Mon Sep 17 00:00:00 2001 From: Yevhenii Melnyk Date: Sat, 27 Jan 2024 19:37:19 +0100 Subject: [PATCH] (new submission) melgenek: ~top 15 on 10k. Buffered IO, VarHandles, vectors, custom hashtable (#600) * melgenek: ~top 15 on 10k. Buffered IO, VarHandles, vectors, custom hashtable * Calculate the required heap size dynamically --- calculate_average_melgenek.sh | 37 ++ prepare_melgenek.sh | 19 + .../onebrc/CalculateAverage_melgenek.java | 549 ++++++++++++++++++ 3 files changed, 605 insertions(+) create mode 100755 calculate_average_melgenek.sh create mode 100755 prepare_melgenek.sh create mode 100644 src/main/java/dev/morling/onebrc/CalculateAverage_melgenek.java diff --git a/calculate_average_melgenek.sh b/calculate_average_melgenek.sh new file mode 100755 index 0000000..e0a88a3 --- /dev/null +++ b/calculate_average_melgenek.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +JAVA_OPTS="--enable-preview --add-modules jdk.incubator.vector -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0" +JAVA_OPTS="$JAVA_OPTS -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -XX:+AlwaysPreTouch" +# These flags are mostly copied from the shipilev's branch. They don't really give a predictable benefit, but they don't hurt either. +JAVA_OPTS="$JAVA_OPTS -XX:-TieredCompilation -XX:CICompilerCount=1 -XX:CompileThreshold=2048 -XX:-UseCountedLoopSafepoints -XX:+TrustFinalNonStaticFields" + +if [[ "$(uname -s)" == "Linux" ]]; then + JAVA_OPTS="$JAVA_OPTS -XX:+UseTransparentHugePages" +fi + +# https://stackoverflow.com/a/23378780/7221823 +logicalCpuCount=$([ $(uname) = 'Darwin' ] && + sysctl -n hw.logicalcpu_max || + lscpu -p | egrep -v '^#' | wc -l) +# The required heap is proportional to the number of cores. +# There's roughly 3.5MB heap per thread required for the 10k problem. +requiredMemory=$(echo "(l(15 + 3.5 * $logicalCpuCount)/l(2))" | bc -l) +heapSize=$(echo "scale=0; 2^(($requiredMemory+1)/1)" | bc) + +JAVA_OPTS="$JAVA_OPTS -Xms${heapSize}m -Xmx${heapSize}m" +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_melgenek diff --git a/prepare_melgenek.sh b/prepare_melgenek.sh new file mode 100755 index 0000000..09c53f6 --- /dev/null +++ b/prepare_melgenek.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.2-open 1>&2 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_melgenek.java b/src/main/java/dev/morling/onebrc/CalculateAverage_melgenek.java new file mode 100644 index 0000000..1335731 --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_melgenek.java @@ -0,0 +1,549 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.Vector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorSpecies; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.TreeMap; +import java.util.concurrent.Executors; + +/** + * The implementation: + * - reads a file with buffered IO + * - uses VarHandles to get longs/ints from a byte array + * - delimiter search is vectorized + * - there is a custom hash function, that provides a low collision rate and short probe distances in hash tables + * - has 2 custom open addressing hash tables: one for strings <=8 bytes in length, and one more for strings of any length + */ +public class CalculateAverage_melgenek { + + private static final VarHandle LONG_VIEW = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.nativeOrder()).withInvokeExactBehavior(); + private static final VarHandle INT_VIEW = MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.nativeOrder()).withInvokeExactBehavior(); + private static final int CORES_COUNT = Runtime.getRuntime().availableProcessors(); + + private static final String FILE = "./measurements.txt"; + + /** + * This is a prime number that gives pretty + * good hash distributions + * on the data in this challenge. + */ + private static final long RANDOM_PRIME = 0x7A646E4D; + private static final int ZERO_CHAR_3_SUM = 100 * '0' + 10 * '0' + '0'; + private static final int ZERO_CHAR_2_SUM = 10 * '0' + '0'; + private static final byte NEWLINE = '\n'; + private static final byte SEMICOLON = ';'; + private static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_PREFERRED; + private static final int BYTE_SPECIES_BYTE_SIZE = BYTE_SPECIES.vectorByteSize(); + private static final Vector NEWLINE_VECTOR = BYTE_SPECIES.broadcast(NEWLINE); + private static final Vector SEMICOLON_VECTOR = BYTE_SPECIES.broadcast(SEMICOLON); + private static final int MAX_LINE_LENGTH = 107; // 100 + len(";-11.1\n") = 100+7 + private static final TreeMap RESULT = new TreeMap<>(); + + public static void main(String[] args) throws Throwable { + long totalSize = Files.size(Path.of(FILE)); + try (var executor = Executors.newFixedThreadPool(CORES_COUNT - 1)) { + long chunkSize = Math.max(1, totalSize / CORES_COUNT); + long offset = 0; + int i = 0; + for (; offset < totalSize && i < CORES_COUNT - 1; i++) { + long currentOffset = offset; + long maxOffset = Math.min((i + 1) * chunkSize, totalSize); + executor.submit(() -> processRange(currentOffset, maxOffset)); + offset = (i + 1) * chunkSize - 1; + } + if (offset < totalSize) { + processRange(offset, totalSize); + } + } + System.out.println(RESULT); + } + + private static void processRange(long startOffset, long maxOffset) { + final var table = new CompositeTable(); + try (var file = new BufferedFile(startOffset, maxOffset)) { + processChunk(file, table); + } + catch (Exception e) { + throw new RuntimeException(e); + } + synchronized (RESULT) { + table.addRows(RESULT); + } + } + + private static void processChunk(BufferedFile file, CompositeTable table) { + if (file.offset != 0) { + file.refillBuffer(); + int newlinePosition = findDelimiter(file, 0, NEWLINE_VECTOR, NEWLINE); + file.bufferPosition = newlinePosition + 1; + file.offset += file.bufferPosition; + } + while (file.offset < file.maxOffset) { + file.refillBuffer(); + int bytesProcessed = processOneRow(file, table); + file.offset += bytesProcessed; + } + } + + private static int processOneRow(BufferedFile file, CompositeTable table) { + int stringStart = file.bufferPosition; + int stringEnd = findDelimiter(file, stringStart, SEMICOLON_VECTOR, SEMICOLON); + + file.bufferPosition = stringEnd + 1; + short value = parseValue(file); + + table.add(file.buffer, stringStart, stringEnd, value); + + return file.bufferPosition - stringStart; + } + + private static short parseValue(BufferedFile file) { + byte firstDigit = file.buffer[file.bufferPosition]; + int sign = 1; + if (firstDigit == '-') { + sign = -1; + file.bufferPosition++; + firstDigit = file.buffer[file.bufferPosition]; + } + + byte secondDigit = file.buffer[file.bufferPosition + 1]; + int result; + if (secondDigit == '.') { + result = firstDigit * 10 + file.buffer[file.bufferPosition + 2] - ZERO_CHAR_2_SUM; + file.bufferPosition += 4; + } + else { + result = firstDigit * 100 + secondDigit * 10 + file.buffer[file.bufferPosition + 3] - ZERO_CHAR_3_SUM; + file.bufferPosition += 5; + } + return (short) (result * sign); + } + + /** + * Finds a delimiter in a byte array using vectorized comparisons. + */ + private static int findDelimiter(BufferedFile file, int startPosition, Vector repeatedDelimiter, byte delimiter) { + int position = startPosition; + int vectorLoopBound = startPosition + BYTE_SPECIES.loopBound(file.bufferLimit - startPosition); + for (; position < vectorLoopBound; position += BYTE_SPECIES_BYTE_SIZE) { + var vector = ByteVector.fromArray(BYTE_SPECIES, file.buffer, position); + var comparisonResult = vector.compare(VectorOperators.EQ, repeatedDelimiter); + if (comparisonResult.anyTrue()) { + return position + comparisonResult.firstTrue(); + } + } + + while (file.buffer[position] != delimiter) { + position++; + } + + return position; + } + + private static long keepLastBytes(long value, int numBytesToKeep) { + // Number of bits to shift, so that the mask covers only `numBytesToKeep` least significant bits + int bitShift = (Long.BYTES - numBytesToKeep) * Byte.SIZE; + // Mask with the specified number of the least significant bits set to 1 + long mask = -1L >>> bitShift; + return value & mask; + } + + /** + * The function transforms a string with the length <=8 bytes to a java String. + * The function assumes that the string is 0 terminated. + */ + private static String longToString(long value) { + int strLength = Long.BYTES - Long.numberOfLeadingZeros(value) / Byte.SIZE; + var bytes = new byte[strLength]; + for (int i = 0; i < strLength; i++) { + bytes[i] = (byte) (value >> (i * Byte.SIZE)); + } + return new String(bytes, StandardCharsets.UTF_8); + } + + /** + * Store measurements based on string lengths. + * Stores strings of length <= 8 and other strings separately. + * This table is a simplified implementation of strings hash table in ClickHouse. + * The original parer that describes benefits of the approach is SAHA: A String Adaptive Hash Table for Analytical Databases. + */ + private static final class CompositeTable { + private final LongTable longTable = new LongTable(); + private final RegularTable regularTable = new RegularTable(); + + private void add(byte[] buffer, int stringStart, int stringEnd, short value) { + int stringLength = stringEnd - stringStart; + if (stringLength <= Long.BYTES) { + long str = keepLastBytes((long) LONG_VIEW.get(buffer, stringStart), stringLength); + this.longTable.add(str, value); + } + else { + int hash = calculateHash(buffer, stringStart, stringEnd); + this.regularTable.add(buffer, stringStart, stringLength, hash, value); + } + } + + public void addRows(TreeMap result) { + this.longTable.addRows(result); + this.regularTable.addRows(result); + } + } + + /** + * The hash calculation is inspired by + * QuestDB FastMap + */ + private static int calculateHash(byte[] buffer, int startPosition, int endPosition) { + long hash = 0; + + int position = startPosition; + for (; position + Long.BYTES <= endPosition; position += Long.BYTES) { + long value = (long) LONG_VIEW.get(buffer, position); + hash = hash * RANDOM_PRIME + value; + } + + if (position + Integer.BYTES <= endPosition) { + int value = (int) INT_VIEW.get(buffer, position); + hash = hash * RANDOM_PRIME + value; + position += Integer.BYTES; + } + + for (; position <= endPosition; position++) { + hash = hash * RANDOM_PRIME + buffer[position]; + } + hash = hash * RANDOM_PRIME; + return (int) hash ^ (int) (hash >>> 32); + } + + private static int calculateLongHash(long str) { + long hash = str * RANDOM_PRIME; + return (int) hash ^ (int) (hash >>> 32); + } + + /** + * This tables stores strings of length <= 8 bytes. + * Does not store hashes. + */ + private static final class LongTable { + private static final int TABLE_CAPACITY = 32768; + private static final int TABLE_CAPACITY_MASK = TABLE_CAPACITY - 1; + /** + * The buckets use 3 longs to store strings and measurements: + * long 1) station name + * long 2) sum of measurements + * long 3) count (int) | min (short) | max (short) <-- packed into one long + */ + private final long[] buckets = new long[TABLE_CAPACITY * 3]; + + int keysCount = 0; + + public void add(long str, short value) { + int hash = calculateLongHash(str); + int bucketIdx = hash & TABLE_CAPACITY_MASK; + + long bucketStr = buckets[bucketIdx * 3]; + if (bucketStr == str) { + updateBucket(bucketIdx, value); + } + else if (bucketStr == 0L) { + createBucket(bucketIdx, str, value); + keysCount++; + } + else { + addWithProbing(str, value, (bucketIdx + 1) & TABLE_CAPACITY_MASK); + } + } + + private void addWithProbing(long str, short value, int bucketIdx) { + int distance = 1; + while (true) { + long bucketStr = buckets[bucketIdx * 3]; + if (bucketStr == str) { + updateBucket(bucketIdx, value); + break; + } + else if (bucketStr == 0L) { + createBucket(bucketIdx, str, value); + keysCount++; + break; + } + else { + distance++; + // A new bucket index is calculated based on quadratic probing https://thenumb.at/Hashtables/#quadratic-probing + // Quadratic probing decreases the number of collisions and max probing distance. + // Linear: + // - capacity 16k, 28.6M collisions, 14-17 max distance + // - capacity 32k, 9.5M collisions, 5-7 max distance + // Quadratic: + // - capacity 16k 25M collisions, 8-10 max distance + // - capacity 32k, 9.3M collisions, 4-7 max distance + bucketIdx = (bucketIdx + distance) & TABLE_CAPACITY_MASK; + } + } + } + + public void addRows(TreeMap result) { + for (int bucketIdx = 0; bucketIdx < TABLE_CAPACITY; bucketIdx++) { + int bucketOffset = bucketIdx * 3; + long str = buckets[bucketOffset]; + if (str != 0L) { + long sum = buckets[bucketOffset + 1]; + long countMinMax = buckets[bucketOffset + 2]; + int count = (int) ((countMinMax >> 32)); + short min = (short) ((countMinMax >> 16) & 0xFFFF); + short max = (short) (countMinMax & 0xFFFF); + + result.compute(longToString(str), (k, resultRow) -> { + if (resultRow == null) { + return new ResultRow(sum, count, min, max); + } + else { + resultRow.add(sum, count, min, max); + return resultRow; + } + }); + } + } + } + + private void createBucket(int bucketIdx, long str, short value) { + int offset = bucketIdx * 3; + buckets[offset] = str; + buckets[offset + 1] = value; + buckets[offset + 2] = (1L << 32) | ((long) (value & 0xFFFF) << 16) | (long) (value & 0xFFFF); + } + + private void updateBucket(int bucketIdx, short value) { + int offset = bucketIdx * 3; + long sum = buckets[offset + 1]; + buckets[offset + 1] = sum + value; + + long countMinMax = buckets[offset + 2]; + int count = (int) ((countMinMax >> 32)); + short min = (short) ((countMinMax >> 16) & 0xFFFF); + short max = (short) (countMinMax & 0xFFFF); + if (value < min) { + min = value; + } + if (value > max) { + max = value; + } + buckets[offset + 2] = ((long) (count + 1) << 32) | ((long) (min & 0xFFFF) << 16) | (long) (max & 0xFFFF); + } + } + + /** + * An open addressing hash table that stores strings as byte arrays. + * Stores hashes. + */ + private static final class RegularTable { + private static final int TABLE_CAPACITY = 16384; + private static final int TABLE_CAPACITY_MASK = TABLE_CAPACITY - 1; + private final Bucket[] buckets = new Bucket[TABLE_CAPACITY]; + + int keysCount = 0; + + public void add(byte[] data, int start, int stringLength, int hash, short value) { + int bucketIdx = hash & TABLE_CAPACITY_MASK; + + var bucket = buckets[bucketIdx]; + if (bucket == null) { + buckets[bucketIdx] = new Bucket(data, start, stringLength, hash, value); + keysCount++; + } + else if (hash == bucket.hash && bucket.isEqual(data, start, stringLength)) { + bucket.update(value); + } + else { + addWithProbing(data, start, stringLength, hash, value, (bucketIdx + 1) & TABLE_CAPACITY_MASK); + } + } + + private void addWithProbing(byte[] data, int start, int stringLength, int hash, short value, int bucketIdx) { + int distance = 1; + while (true) { + var bucket = buckets[bucketIdx]; + if (bucket == null) { + buckets[bucketIdx] = new Bucket(data, start, stringLength, hash, value); + keysCount++; + break; + } + else if (hash == bucket.hash && bucket.isEqual(data, start, stringLength)) { + bucket.update(value); + break; + } + else { + distance++; + bucketIdx = (bucketIdx + distance) & TABLE_CAPACITY_MASK; + } + } + } + + public void addRows(TreeMap result) { + for (var bucket : buckets) { + if (bucket != null) { + result.compute(new String(bucket.str, StandardCharsets.UTF_8), (k, resultRow) -> { + if (resultRow == null) { + return new ResultRow(bucket.sum, bucket.count, bucket.min, bucket.max); + } + else { + resultRow.add(bucket.sum, bucket.count, bucket.min, bucket.max); + return resultRow; + } + }); + } + } + } + + private static final class Bucket { + int hash; + byte[] str; + long sum; + int count; + short max = Short.MIN_VALUE; + short min = Short.MAX_VALUE; + + Bucket(byte[] data, int start, int stringLength, int hash, short value) { + this.str = new byte[stringLength]; + System.arraycopy(data, start, this.str, 0, stringLength); + this.hash = hash; + update(value); + } + + public void update(short value) { + this.sum += value; + this.count++; + if (max < value) + max = value; + if (min > value) + min = value; + } + + public boolean isEqual(byte[] data, int start, int length) { + if (str.length != length) + return false; + int i = 0; + for (; i + Long.BYTES < str.length; i += Long.BYTES) { + long value1 = (long) LONG_VIEW.get(str, i); + long value2 = (long) LONG_VIEW.get(data, start + i); + if (value1 != value2) + return false; + } + if (i + Integer.BYTES < str.length) { + int value1 = (int) INT_VIEW.get(str, i); + int value2 = (int) INT_VIEW.get(data, start + i); + if (value1 != value2) + return false; + i += Integer.BYTES; + } + for (; i < str.length; i++) { + if (data[start + i] != str[i]) + return false; + } + return true; + } + } + } + + private static class ResultRow { + long sum; + int count; + short min; + short max; + + public ResultRow(long sum, int count, short min, short max) { + this.sum = sum; + this.count = count; + this.min = min; + this.max = max; + } + + public void add(long anotherSum, int anotherCount, short anotherMin, short anotherMax) { + sum += anotherSum; + count += anotherCount; + if (max < anotherMax) + max = anotherMax; + if (min > anotherMin) + min = anotherMin; + } + + public String toString() { + return Math.round((double) min) / 10.0 + "/" + + Math.round((double) sum / count) / 10.0 + "/" + + Math.round((double) max) / 10.0; + } + } + + /** + * A utility class that uses the RandomAccessFile to read at offset. + * Keeps the in-memory buffer, as well as current offsets in the buffer and the file. + */ + private static final class BufferedFile implements AutoCloseable { + private static final int BUFFER_SIZE = 512 * 1024; + private final byte[] buffer = new byte[BUFFER_SIZE]; + private int bufferLimit = 0; + private int bufferPosition = 0; + private final long maxOffset; + private final RandomAccessFile file; + private long offset; + + private BufferedFile(long startOffset, long maxOffset) throws FileNotFoundException { + this.offset = startOffset; + this.maxOffset = maxOffset; + this.file = new RandomAccessFile(FILE, "r"); + } + + private void refillBuffer() { + int remainingBytes = bufferLimit - bufferPosition; + if (remainingBytes < MAX_LINE_LENGTH) { + bufferPosition = 0; + int bytesRead; + try { + file.seek(offset); + bytesRead = file.read(buffer, 0, BUFFER_SIZE); + } + catch (IOException e) { + throw new RuntimeException(e); + } + if (bytesRead > 0) { + bufferLimit = bytesRead; + } + else { + bufferLimit = 0; + } + } + } + + @Override + public void close() throws Exception { + file.close(); + } + } + +}