diff --git a/calculate_average_dmitry-midokura.sh b/calculate_average_dmitry-midokura.sh new file mode 100755 index 0000000..e4d1366 --- /dev/null +++ b/calculate_average_dmitry-midokura.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +#JAVA_OPTS="-verbose:gc" +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_bufistov $1 $2 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java b/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java new file mode 100644 index 0000000..db60403 --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java @@ -0,0 +1,398 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import static java.lang.Math.toIntExact; + +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import java.io.FileInputStream; +import java.io.IOException; +import java.util.concurrent.Future; + +class ResultRow { + byte[] station; + + String stationString; + long min, max, count, suma; + + ResultRow() { + } + + ResultRow(byte[] station, long value) { + this.station = new byte[station.length]; + System.arraycopy(station, 0, this.station, 0, station.length); + this.min = value; + this.max = value; + this.count = 1; + this.suma = value; + } + + ResultRow(long value) { + this.min = value; + this.max = value; + this.count = 1; + this.suma = value; + } + + void setStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition) { + this.station = new byte[endPosition - startPosition]; + byteBuffer.slice(startPosition, station.length).get(this.station, 0, station.length); + } + + public String toString() { + stationString = new String(station, StandardCharsets.UTF_8); + return stationString + "=" + round(min / 10.0) + "/" + round(suma / 10.0 / count) + "/" + round(max / 10.0); + } + + private double round(double value) { + return Math.round(value * 10.0) / 10.0; + } + + ResultRow update(long newValue) { + this.count += 1; + this.suma += newValue; + if (newValue < this.min) { + this.min = newValue; + } + else if (newValue > this.max) { + this.max = newValue; + } + return this; + } + + ResultRow merge(ResultRow another) { + this.count += another.count; + this.suma += another.suma; + this.min = Math.min(this.min, another.min); + this.max = Math.max(this.max, another.max); + return this; + } +} + +class ByteArrayWrapper { + private final byte[] data; + + public ByteArrayWrapper(byte[] data) { + this.data = data; + } + + @Override + public boolean equals(Object other) { + return Arrays.equals(data, ((ByteArrayWrapper) other).data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } +} + +class OpenHash { + ResultRow[] data; + int dataSizeMask; + + // ResultRow metrics = new ResultRow(); + + public OpenHash(int capacityPow2) { + assert capacityPow2 <= 20; + int dataSize = 1 << capacityPow2; + dataSizeMask = dataSize - 1; + data = new ResultRow[dataSize]; + } + + int hashByteArray(byte[] array) { + int result = 0; + long mask = 0; + for (int i = 0; i < array.length; ++i, mask = ((mask + 1) & 3)) { + result += array[i] << mask; + } + return result & dataSizeMask; + } + + void merge(byte[] station, long value, int hashValue) { + while (data[hashValue] != null && !Arrays.equals(station, data[hashValue].station)) { + hashValue += 1; + hashValue &= dataSizeMask; + } + if (data[hashValue] == null) { + data[hashValue] = new ResultRow(station, value); + } + else { + data[hashValue].update(value); + } + // metrics.update(delta); + } + + void merge(byte[] station, long value) { + merge(station, value, hashByteArray(station)); + } + + void merge(MappedByteBuffer byteBuffer, final int startPosition, final int endPosition, int hashValue, final long value) { + while (data[hashValue] != null && !equalsToStation(byteBuffer, startPosition, endPosition, data[hashValue].station)) { + hashValue += 1; + hashValue &= dataSizeMask; + } + if (data[hashValue] == null) { + data[hashValue] = new ResultRow(value); + data[hashValue].setStation(byteBuffer, startPosition, endPosition); + } + else { + data[hashValue].update(value); + } + } + + boolean equalsToStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition, byte[] station) { + if (endPosition - startPosition != station.length) { + return false; + } + for (int i = 0; i < station.length; ++i, ++startPosition) { + if (byteBuffer.get(startPosition) != station[i]) + return false; + } + return true; + } + + HashMap toJavaHashMap() { + HashMap result = new HashMap<>(20000); + for (int i = 0; i < data.length; ++i) { + if (data[i] != null) { + var key = new ByteArrayWrapper(data[i].station); + result.put(key, data[i]); + } + } + return result; + } +} + +public class CalculateAverage_bufistov { + + static final long LINE_SEPARATOR = '\n'; + + public static class FileRead implements Callable> { + + private final FileChannel fileChannel; + private long currentLocation; + private int bytesToRead; + + private final int hashCapacityPow2 = 18; + private final int hashCapacityMask = (1 << hashCapacityPow2) - 1; + + public FileRead(long startLocation, int bytesToRead, FileChannel fileChannel) { + this.currentLocation = startLocation; + this.bytesToRead = bytesToRead; + this.fileChannel = fileChannel; + } + + @Override + public HashMap call() throws IOException { + try { + OpenHash openHash = new OpenHash(hashCapacityPow2); + log("Reading the channel: " + currentLocation + ":" + bytesToRead); + byte[] suffix = new byte[128]; + if (currentLocation > 0) { + toLineBegin(suffix); + } + while (bytesToRead > 0) { + int bufferSize = Math.min(1 << 24, bytesToRead); + MappedByteBuffer byteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, currentLocation, bufferSize); + bytesToRead -= bufferSize; + currentLocation += bufferSize; + int suffixBytes = 0; + if (currentLocation < fileChannel.size()) { + suffixBytes = toLineBegin(suffix); + } + processChunk(byteBuffer, bufferSize, suffix, suffixBytes, openHash); + } + log("Done Reading the channel: " + currentLocation + ":" + bytesToRead); + return openHash.toJavaHashMap(); + } + catch (Exception e) { + e.printStackTrace(); + throw e; + } + } + + byte getByte(long position) throws IOException { + MappedByteBuffer byteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, position, 1); + return byteBuffer.get(); + } + + int toLineBegin(byte[] suffix) throws IOException { + int bytesConsumed = 0; + if (getByte(currentLocation - 1) != LINE_SEPARATOR) { + while (getByte(currentLocation) != LINE_SEPARATOR) { // Small bug here if last chunk is less than a line and has no '\n' at the end. Valid input should have '\n' for all rows. + suffix[bytesConsumed++] = getByte(currentLocation); + ++currentLocation; + --bytesToRead; + } + ++currentLocation; + --bytesToRead; + } + return bytesConsumed; + } + + void processChunk(MappedByteBuffer byteBuffer, int bufferSize, byte[] suffix, int suffixBytes, OpenHash result) { + int nameBegin = 0; + int nameEnd = -1; + int numberBegin = -1; + int currentHash = 0; + int currentMask = 0; + int nameHash = 0; + for (int currentPosition = 0; currentPosition < bufferSize; ++currentPosition) { + byte nextByte = byteBuffer.get(currentPosition); + if (nextByte == ';') { + nameEnd = currentPosition; + numberBegin = currentPosition + 1; + nameHash = currentHash & hashCapacityMask; + } + else if (nextByte == LINE_SEPARATOR) { + long value = getValue(byteBuffer, numberBegin, currentPosition); + // log("Station name: '" + getStationName(byteBuffer, nameBegin, nameEnd) + "' value: " + value + " hash: " + nameHash); + result.merge(byteBuffer, nameBegin, nameEnd, nameHash, value); + nameBegin = currentPosition + 1; + currentHash = 0; + currentMask = 0; + } + else { + currentHash += (nextByte << currentMask); + currentMask = (currentMask + 1) & 3; + } + } + if (nameBegin < bufferSize) { + byte[] lastLine = new byte[bufferSize - nameBegin + suffixBytes]; + byte[] prefix = new byte[bufferSize - nameBegin]; + byteBuffer.slice(nameBegin, prefix.length).get(prefix, 0, prefix.length); + System.arraycopy(prefix, 0, lastLine, 0, prefix.length); + System.arraycopy(suffix, 0, lastLine, prefix.length, suffixBytes); + processLastLine(lastLine, result); + } + } + + void processLastLine(byte[] lastLine, OpenHash result) { + int numberBegin = -1; + byte[] stationName = null; + for (int i = 0; i < lastLine.length; ++i) { + if (lastLine[i] == ';') { + stationName = new byte[i]; + System.arraycopy(lastLine, 0, stationName, 0, stationName.length); + numberBegin = i + 1; + break; + } + } + long value = getValue(lastLine, numberBegin); + // log("Station name: '" + new String(stationName, StandardCharsets.UTF_8) + "' value: " + value); + result.merge(stationName, value); + } + + long getValue(MappedByteBuffer byteBuffer, int startLocation, int endLocation) { + byte nextByte = byteBuffer.get(startLocation); + boolean negate = nextByte == '-'; + long result = negate ? 0 : nextByte - '0'; + for (int i = startLocation + 1; i < endLocation; ++i) { + nextByte = byteBuffer.get(i); + if (nextByte != '.') { + result *= 10; + result += nextByte - '0'; + } + } + return negate ? -result : result; + } + + long getValue(byte[] lastLine, int startLocation) { + byte nextByte = lastLine[startLocation]; + boolean negate = nextByte == '-'; + long result = negate ? 0 : nextByte - '0'; + for (int i = startLocation + 1; i < lastLine.length; ++i) { + nextByte = lastLine[i]; + if (nextByte != '.') { + result *= 10; + result += nextByte - '0'; + } + } + return negate ? -result : result; + } + + String getStationName(MappedByteBuffer byteBuffer, int from, int to) { + byte[] bytes = new byte[to - from]; + byteBuffer.slice(from, to - from).get(0, bytes); + return new String(bytes, StandardCharsets.UTF_8); + } + } + + public static void main(String[] args) throws Exception { + String fileName = "measurements.txt"; + if (args.length > 0 && args[0].length() > 0) { + fileName = args[0]; + } + log("InputFile: " + fileName); + FileInputStream fileInputStream = new FileInputStream(fileName); + int numThreads = 32; + if (args.length > 1) { + numThreads = Integer.parseInt(args[1]); + } + log("NumThreads: " + numThreads); + FileChannel channel = fileInputStream.getChannel(); + final long fileSize = channel.size(); + long remaining_size = fileSize; + long chunk_size = Math.min((fileSize + numThreads - 1) / numThreads, Integer.MAX_VALUE - 5); + + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + + long startLocation = 0; + ArrayList>> results = new ArrayList<>(numThreads); + while (remaining_size > 0) { + long actualSize = Math.min(chunk_size, remaining_size); + results.add(executor.submit(new FileRead(startLocation, toIntExact(actualSize), channel))); + remaining_size -= actualSize; + startLocation += actualSize; + } + executor.shutdown(); + + // Wait for all threads to finish + while (!executor.isTerminated()) { + Thread.yield(); + } + log("Finished all threads"); + fileInputStream.close(); + HashMap result = new HashMap<>(20000); + for (var future : results) { + for (var entry : future.get().entrySet()) { + result.merge(entry.getKey(), entry.getValue(), ResultRow::merge); + } + } + ResultRow[] finalResult = result.values().toArray(new ResultRow[0]); + for (var row : finalResult) { + row.toString(); + } + Arrays.sort(finalResult, Comparator.comparing(a -> a.stationString)); + System.out.println("{" + String.join(", ", Arrays.stream(finalResult).map(ResultRow::toString).toList()) + "}"); + log("All done!"); + } + + static void log(String message) { + // System.err.println(Instant.now() + "[" + Thread.currentThread().getName() + "]: " + message); + } +}