From b20e7365e72c092f1800ea814e85d51f9bf53917 Mon Sep 17 00:00:00 2001 From: Dmitry Bufistov <112496477+dmitry-midokura@users.noreply.github.com> Date: Thu, 25 Jan 2024 23:09:22 +0100 Subject: [PATCH] Second submission to keep a bit of dignity (#581) * Dmitry challenge * Dmitry submit 2. Use MemorySegment of FileChannel and Unsafe to read bytes from disk. 4 seconds speedup in local test from 20s to 16s. --- calculate_average_dmitry-midokura.sh | 1 + .../onebrc/CalculateAverage_bufistov.java | 406 +++++++++--------- 2 files changed, 195 insertions(+), 212 deletions(-) diff --git a/calculate_average_dmitry-midokura.sh b/calculate_average_dmitry-midokura.sh index e4d1366..1bb529b 100755 --- a/calculate_average_dmitry-midokura.sh +++ b/calculate_average_dmitry-midokura.sh @@ -17,4 +17,5 @@ #JAVA_OPTS="-verbose:gc" +JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_bufistov $1 $2 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java b/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java index db60403..178a6e1 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java @@ -15,11 +15,17 @@ */ package dev.morling.onebrc; +import sun.misc.Unsafe; + import static java.lang.Math.toIntExact; +import java.lang.foreign.Arena; +import java.lang.reflect.Field; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; +import java.nio.file.Paths; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; @@ -32,66 +38,6 @@ import java.io.FileInputStream; import java.io.IOException; import java.util.concurrent.Future; -class ResultRow { - byte[] station; - - String stationString; - long min, max, count, suma; - - ResultRow() { - } - - 
ResultRow(byte[] station, long value) { - this.station = new byte[station.length]; - System.arraycopy(station, 0, this.station, 0, station.length); - this.min = value; - this.max = value; - this.count = 1; - this.suma = value; - } - - ResultRow(long value) { - this.min = value; - this.max = value; - this.count = 1; - this.suma = value; - } - - void setStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition) { - this.station = new byte[endPosition - startPosition]; - byteBuffer.slice(startPosition, station.length).get(this.station, 0, station.length); - } - - public String toString() { - stationString = new String(station, StandardCharsets.UTF_8); - return stationString + "=" + round(min / 10.0) + "/" + round(suma / 10.0 / count) + "/" + round(max / 10.0); - } - - private double round(double value) { - return Math.round(value * 10.0) / 10.0; - } - - ResultRow update(long newValue) { - this.count += 1; - this.suma += newValue; - if (newValue < this.min) { - this.min = newValue; - } - else if (newValue > this.max) { - this.max = newValue; - } - return this; - } - - ResultRow merge(ResultRow another) { - this.count += another.count; - this.suma += another.suma; - this.min = Math.min(this.min, another.min); - this.max = Math.max(this.max, another.max); - return this; - } -} - class ByteArrayWrapper { private final byte[] data; @@ -110,100 +56,176 @@ class ByteArrayWrapper { } } -class OpenHash { - ResultRow[] data; - int dataSizeMask; +public class CalculateAverage_bufistov { - // ResultRow metrics = new ResultRow(); + static class ResultRow { + byte[] station; - public OpenHash(int capacityPow2) { - assert capacityPow2 <= 20; - int dataSize = 1 << capacityPow2; - dataSizeMask = dataSize - 1; - data = new ResultRow[dataSize]; - } + String stationString; + long min, max, count, suma; - int hashByteArray(byte[] array) { - int result = 0; - long mask = 0; - for (int i = 0; i < array.length; ++i, mask = ((mask + 1) & 3)) { - result += array[i] << mask; + 
ResultRow() { } - return result & dataSizeMask; - } - void merge(byte[] station, long value, int hashValue) { - while (data[hashValue] != null && !Arrays.equals(station, data[hashValue].station)) { - hashValue += 1; - hashValue &= dataSizeMask; + ResultRow(byte[] station, long value) { + this.station = new byte[station.length]; + System.arraycopy(station, 0, this.station, 0, station.length); + this.min = value; + this.max = value; + this.count = 1; + this.suma = value; } - if (data[hashValue] == null) { - data[hashValue] = new ResultRow(station, value); - } - else { - data[hashValue].update(value); - } - // metrics.update(delta); - } - void merge(byte[] station, long value) { - merge(station, value, hashByteArray(station)); - } + ResultRow(long value) { + this.min = value; + this.max = value; + this.count = 1; + this.suma = value; + } - void merge(MappedByteBuffer byteBuffer, final int startPosition, final int endPosition, int hashValue, final long value) { - while (data[hashValue] != null && !equalsToStation(byteBuffer, startPosition, endPosition, data[hashValue].station)) { - hashValue += 1; - hashValue &= dataSizeMask; - } - if (data[hashValue] == null) { - data[hashValue] = new ResultRow(value); - data[hashValue].setStation(byteBuffer, startPosition, endPosition); - } - else { - data[hashValue].update(value); - } - } - - boolean equalsToStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition, byte[] station) { - if (endPosition - startPosition != station.length) { - return false; - } - for (int i = 0; i < station.length; ++i, ++startPosition) { - if (byteBuffer.get(startPosition) != station[i]) - return false; - } - return true; - } - - HashMap toJavaHashMap() { - HashMap result = new HashMap<>(20000); - for (int i = 0; i < data.length; ++i) { - if (data[i] != null) { - var key = new ByteArrayWrapper(data[i].station); - result.put(key, data[i]); + void setStation(long startPosition, long endPosition) { + this.station = new byte[(int) 
(endPosition - startPosition)]; + for (int i = 0; i < this.station.length; ++i) { + this.station[i] = UNSAFE.getByte(startPosition + i); } } - return result; - } -} -public class CalculateAverage_bufistov { + public String toString() { + stationString = new String(station, StandardCharsets.UTF_8); + return stationString + "=" + round(min / 10.0) + "/" + round(suma / 10.0 / count) + "/" + round(max / 10.0); + } + + private double round(double value) { + return Math.round(value * 10.0) / 10.0; + } + + void update(long newValue) { + this.count += 1; + this.suma += newValue; + if (newValue < this.min) { + this.min = newValue; + } + else if (newValue > this.max) { + this.max = newValue; + } + } + + ResultRow merge(ResultRow another) { + this.count += another.count; + this.suma += another.suma; + this.min = Math.min(this.min, another.min); + this.max = Math.max(this.max, another.max); + return this; + } + } + + static class OpenHash { + ResultRow[] data; + int dataSizeMask; + + // ResultRow metrics = new ResultRow(); + + public OpenHash(int capacityPow2) { + assert capacityPow2 <= 20; + int dataSize = 1 << capacityPow2; + dataSizeMask = dataSize - 1; + data = new ResultRow[dataSize]; + } + + int hashByteArray(byte[] array) { + int result = 0; + long mask = 0; + for (int i = 0; i < array.length; ++i, mask = ((mask + 1) & 3)) { + result += array[i] << mask; + } + return result & dataSizeMask; + } + + void merge(byte[] station, long value, int hashValue) { + while (data[hashValue] != null && !Arrays.equals(station, data[hashValue].station)) { + hashValue += 1; + hashValue &= dataSizeMask; + } + if (data[hashValue] == null) { + data[hashValue] = new ResultRow(station, value); + } + else { + data[hashValue].update(value); + } + // metrics.update(delta); + } + + void merge(byte[] station, long value) { + merge(station, value, hashByteArray(station)); + } + + void merge(final long startPosition, long endPosition, int hashValue, long value) { + while (data[hashValue] != null && 
!equalsToStation(startPosition, endPosition, data[hashValue].station)) { + hashValue += 1; + hashValue &= dataSizeMask; + } + if (data[hashValue] == null) { + data[hashValue] = new ResultRow(value); + data[hashValue].setStation(startPosition, endPosition); + } + else { + data[hashValue].update(value); + } + } + + boolean equalsToStation(long startPosition, long endPosition, byte[] station) { + if (endPosition - startPosition != station.length) { + return false; + } + for (int i = 0; i < station.length; ++i, ++startPosition) { + if (UNSAFE.getByte(startPosition) != station[i]) + return false; + } + return true; + } + + HashMap toJavaHashMap() { + HashMap result = new HashMap<>(20000); + for (int i = 0; i < data.length; ++i) { + if (data[i] != null) { + var key = new ByteArrayWrapper(data[i].station); + result.put(key, data[i]); + } + } + return result; + } + } + + static final Unsafe UNSAFE; + + static { + try { + Field unsafe = Unsafe.class.getDeclaredField("theUnsafe"); + unsafe.setAccessible(true); + UNSAFE = (Unsafe) unsafe.get(Unsafe.class); + } + catch (Throwable e) { + throw new RuntimeException(e); + } + } static final long LINE_SEPARATOR = '\n'; public static class FileRead implements Callable> { private final FileChannel fileChannel; + private long currentLocation; - private int bytesToRead; + private long bytesToRead; - private final int hashCapacityPow2 = 18; - private final int hashCapacityMask = (1 << hashCapacityPow2) - 1; + private static final int hashCapacityPow2 = 18; - public FileRead(long startLocation, int bytesToRead, FileChannel fileChannel) { + static final int hashCapacityMask = (1 << hashCapacityPow2) - 1; + + public FileRead(FileChannel fileChannel, long startLocation, long bytesToRead, boolean firstSegment) { + this.fileChannel = fileChannel; this.currentLocation = startLocation; this.bytesToRead = bytesToRead; - this.fileChannel = fileChannel; } @Override @@ -211,21 +233,13 @@ public class CalculateAverage_bufistov { try { OpenHash 
openHash = new OpenHash(hashCapacityPow2); log("Reading the channel: " + currentLocation + ":" + bytesToRead); - byte[] suffix = new byte[128]; if (currentLocation > 0) { - toLineBegin(suffix); - } - while (bytesToRead > 0) { - int bufferSize = Math.min(1 << 24, bytesToRead); - MappedByteBuffer byteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, currentLocation, bufferSize); - bytesToRead -= bufferSize; - currentLocation += bufferSize; - int suffixBytes = 0; - if (currentLocation < fileChannel.size()) { - suffixBytes = toLineBegin(suffix); - } - processChunk(byteBuffer, bufferSize, suffix, suffixBytes, openHash); + toLineBeginPrefix(); } + toLineBeginSuffix(); + var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, currentLocation, bytesToRead, Arena.global()); + currentLocation = memorySegment.address(); + processChunk(openHash); log("Done Reading the channel: " + currentLocation + ":" + bytesToRead); return openHash.toJavaHashMap(); } @@ -240,39 +254,40 @@ public class CalculateAverage_bufistov { return byteBuffer.get(); } - int toLineBegin(byte[] suffix) throws IOException { - int bytesConsumed = 0; - if (getByte(currentLocation - 1) != LINE_SEPARATOR) { - while (getByte(currentLocation) != LINE_SEPARATOR) { // Small bug here if last chunk is less than a line and has no '\n' at the end. Valid input should have '\n' for all rows. 
- suffix[bytesConsumed++] = getByte(currentLocation); - ++currentLocation; - --bytesToRead; - } + void toLineBeginPrefix() throws IOException { + while (getByte(currentLocation - 1) != LINE_SEPARATOR) { ++currentLocation; --bytesToRead; } - return bytesConsumed; } - void processChunk(MappedByteBuffer byteBuffer, int bufferSize, byte[] suffix, int suffixBytes, OpenHash result) { - int nameBegin = 0; - int nameEnd = -1; - int numberBegin = -1; + void toLineBeginSuffix() throws IOException { + while (getByte(currentLocation + bytesToRead - 1) != LINE_SEPARATOR) { + ++bytesToRead; + } + } + + void processChunk(OpenHash result) { + long nameBegin = currentLocation; + long nameEnd = -1; + long numberBegin = -1; int currentHash = 0; int currentMask = 0; int nameHash = 0; - for (int currentPosition = 0; currentPosition < bufferSize; ++currentPosition) { - byte nextByte = byteBuffer.get(currentPosition); + long end = currentLocation + bytesToRead; + byte nextByte; + for (; currentLocation < end; ++currentLocation) { + nextByte = UNSAFE.getByte(currentLocation); if (nextByte == ';') { - nameEnd = currentPosition; - numberBegin = currentPosition + 1; + nameEnd = currentLocation; + numberBegin = currentLocation + 1; nameHash = currentHash & hashCapacityMask; } else if (nextByte == LINE_SEPARATOR) { - long value = getValue(byteBuffer, numberBegin, currentPosition); - // log("Station name: '" + getStationName(byteBuffer, nameBegin, nameEnd) + "' value: " + value + " hash: " + nameHash); - result.merge(byteBuffer, nameBegin, nameEnd, nameHash, value); - nameBegin = currentPosition + 1; + long value = getValue(numberBegin, currentLocation); + // log("Station name: '" + getStationName(nameBegin, nameEnd) + "' value: " + value + " hash: " + nameHash); + result.merge(nameBegin, nameEnd, nameHash, value); + nameBegin = currentLocation + 1; currentHash = 0; currentMask = 0; } @@ -281,38 +296,14 @@ public class CalculateAverage_bufistov { currentMask = (currentMask + 1) & 3; } } - if 
(nameBegin < bufferSize) { - byte[] lastLine = new byte[bufferSize - nameBegin + suffixBytes]; - byte[] prefix = new byte[bufferSize - nameBegin]; - byteBuffer.slice(nameBegin, prefix.length).get(prefix, 0, prefix.length); - System.arraycopy(prefix, 0, lastLine, 0, prefix.length); - System.arraycopy(suffix, 0, lastLine, prefix.length, suffixBytes); - processLastLine(lastLine, result); - } } - void processLastLine(byte[] lastLine, OpenHash result) { - int numberBegin = -1; - byte[] stationName = null; - for (int i = 0; i < lastLine.length; ++i) { - if (lastLine[i] == ';') { - stationName = new byte[i]; - System.arraycopy(lastLine, 0, stationName, 0, stationName.length); - numberBegin = i + 1; - break; - } - } - long value = getValue(lastLine, numberBegin); - // log("Station name: '" + new String(stationName, StandardCharsets.UTF_8) + "' value: " + value); - result.merge(stationName, value); - } - - long getValue(MappedByteBuffer byteBuffer, int startLocation, int endLocation) { - byte nextByte = byteBuffer.get(startLocation); + long getValue(long startLocation, long endLocation) { + byte nextByte = UNSAFE.getByte(startLocation); boolean negate = nextByte == '-'; long result = negate ? 0 : nextByte - '0'; - for (int i = startLocation + 1; i < endLocation; ++i) { - nextByte = byteBuffer.get(i); + for (long i = startLocation + 1; i < endLocation; ++i) { + nextByte = UNSAFE.getByte(i); if (nextByte != '.') { result *= 10; result += nextByte - '0'; @@ -321,23 +312,11 @@ public class CalculateAverage_bufistov { return negate ? -result : result; } - long getValue(byte[] lastLine, int startLocation) { - byte nextByte = lastLine[startLocation]; - boolean negate = nextByte == '-'; - long result = negate ? 
0 : nextByte - '0'; - for (int i = startLocation + 1; i < lastLine.length; ++i) { - nextByte = lastLine[i]; - if (nextByte != '.') { - result *= 10; - result += nextByte - '0'; - } + String getStationName(long from, long to) { + byte[] bytes = new byte[(int) (to - from)]; + for (int i = 0; i < bytes.length; ++i) { + bytes[i] = UNSAFE.getByte(from + i); } - return negate ? -result : result; - } - - String getStationName(MappedByteBuffer byteBuffer, int from, int to) { - byte[] bytes = new byte[to - from]; - byteBuffer.slice(from, to - from).get(0, bytes); return new String(bytes, StandardCharsets.UTF_8); } } @@ -349,7 +328,7 @@ public class CalculateAverage_bufistov { } log("InputFile: " + fileName); FileInputStream fileInputStream = new FileInputStream(fileName); - int numThreads = 32; + int numThreads = 2 * Runtime.getRuntime().availableProcessors(); if (args.length > 1) { numThreads = Integer.parseInt(args[1]); } @@ -363,9 +342,12 @@ public class CalculateAverage_bufistov { long startLocation = 0; ArrayList>> results = new ArrayList<>(numThreads); + var fileChannel = FileChannel.open(Paths.get(fileName)); + boolean firstSegment = true; while (remaining_size > 0) { long actualSize = Math.min(chunk_size, remaining_size); - results.add(executor.submit(new FileRead(startLocation, toIntExact(actualSize), channel))); + results.add(executor.submit(new FileRead(fileChannel, startLocation, toIntExact(actualSize), firstSegment))); + firstSegment = false; remaining_size -= actualSize; startLocation += actualSize; }