diff --git a/calculate_average_isolgpus.sh b/calculate_average_isolgpus.sh new file mode 100755 index 0000000..9d48e59 --- /dev/null +++ b/calculate_average_isolgpus.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +JAVA_OPTS="--enable-preview" +time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_isolgpus diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_isolgpus.java b/src/main/java/dev/morling/onebrc/CalculateAverage_isolgpus.java new file mode 100644 index 0000000..65528c4 --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_isolgpus.java @@ -0,0 +1,293 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import java.io.File; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.nio.BufferUnderflowException; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Paths; +import java.util.*; +import java.util.concurrent.*; +import java.util.stream.Collectors; + +public class CalculateAverage_isolgpus { + + public static final int HISTOGRAMS_LENGTH = 1024 * 32; + public static final int HISTOGRAMS_MASK = HISTOGRAMS_LENGTH - 1; + public static final int THREAD_COUNT = 8; + private static final String FILE = "./measurements.txt"; + public static final byte SEPERATOR = 59; + public static final byte OFFSET = 48; + public static final byte NEGATIVE = 45; + public static final byte DECIMAL_POINT = 46; + public static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 100; // bit of wiggle room + public static final byte NEW_LINE = 10; + + public static void main(String[] args) throws IOException, InterruptedException, ExecutionException { + ExecutorService executorService = Executors.newFixedThreadPool(THREAD_COUNT); + + File file = Paths.get(FILE).toFile(); + long length = file.length(); + long chunksCount = Math.max(THREAD_COUNT, (int) Math.ceil(length / (double) MAX_CHUNK_SIZE)); + + long estimatedChunkSize = length / chunksCount; + + FileChannel channel = new RandomAccessFile(file, "r").getChannel(); + + List> futures = new ArrayList<>(); + for (int i = 0; i < chunksCount; i++) { + int finalI = i; + futures.add(executorService.submit(() -> handleChunk(channel, estimatedChunkSize * finalI, estimatedChunkSize, length))); + } + + List measurementCollectors = new ArrayList<>(); + for (Future result : futures) { + measurementCollectors.add(result.get()); + } + executorService.shutdown(); + + Map measurementCollectorsByCity = mergeMeasurements(measurementCollectors); + List results = measurementCollectorsByCity.values().stream().map(MeasurementResult::from).toList(); + + System.out.println("{" + results.stream().map(MeasurementResult::toString).collect(Collectors.joining(", ")) + "}"); + + } + + private static Map mergeMeasurements(List resultsFromAllChunk) { + Map mergedResults = new TreeMap<>(Comparator.naturalOrder()); + + for (int i = 0; i < HISTOGRAMS_LENGTH; i++) { + for (MeasurementCollector[] resultFromSpecificChunk : resultsFromAllChunk) { + MeasurementCollector measurementCollectorFromChunk = resultFromSpecificChunk[i]; + while (measurementCollectorFromChunk != null) { + MeasurementCollector currentMergedResult = mergedResults.get(new String(measurementCollectorFromChunk.name)); + if (currentMergedResult == null) { + currentMergedResult = new MeasurementCollector(measurementCollectorFromChunk.name); + mergedResults.put(new String(currentMergedResult.name), currentMergedResult); + } + currentMergedResult.merge(measurementCollectorFromChunk); + measurementCollectorFromChunk = measurementCollectorFromChunk.link; + } + } + } + + return mergedResults; + } + + // ----n--- + private static MeasurementCollector[] handleChunk(FileChannel channel, long estimatedStart, long lengthOfChunk, long maxLengthOfFile) throws IOException { + // -1 to see if we're starting on a brand new message + // +200 for wiggle room to finish the final message + + long seekStart = Math.max(estimatedStart - 1, 0); + long length = Math.min(lengthOfChunk + 200, maxLengthOfFile - seekStart); + + MappedByteBuffer r = channel.map(FileChannel.MapMode.READ_ONLY, seekStart, length); + + byte[] nameBuffer = new byte[100]; + boolean isNegative; + byte[] valueBuffer = new byte[3]; + MeasurementCollector[] measurementCollectors = new MeasurementCollector[HISTOGRAMS_LENGTH]; + int valueIndex = 0; + int nameBufferIndex = 0; + int nameSum = 0; + boolean parsingName = true; + long i = 0; + int hashResult = 0; + + // seek to the start of the next message + if (estimatedStart != 0) { + while (r.get() != NEW_LINE) { + i++; + } + i++; + } + + try { + + while (i <= lengthOfChunk || !parsingName) { + byte aChar; + if (parsingName) { + + while ((aChar = r.get()) != SEPERATOR) { + nameBuffer[nameBufferIndex++] = aChar; + nameSum += aChar; + hashResult = 31 * hashResult + aChar; + } + parsingName = false; + i += nameBufferIndex + 1; + } + else { + isNegative = (aChar = r.get()) == NEGATIVE; + valueIndex = readNumber(isNegative, valueBuffer, valueIndex, aChar, r); + + byte decimalValue = r.get(); + + int value = resolveValue(valueIndex, valueBuffer, decimalValue, isNegative); + // new line character + r.get(); + + MeasurementCollector measurementCollector = resolveMeasurementCollector(measurementCollectors, hashResult, nameBuffer, nameBufferIndex, nameSum); + + measurementCollector.feed(value); + i += valueIndex + (isNegative ? 4 : 3); + valueIndex = 0; + nameBufferIndex = 0; + nameSum = 0; + parsingName = true; + hashResult = 0; + } + } + + } + catch (BufferUnderflowException e) { + if (i != maxLengthOfFile - seekStart) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } + + return measurementCollectors; + } + + private static MeasurementCollector resolveMeasurementCollector(MeasurementCollector[] measurementCollectors, int hash, byte[] nameBuffer, int nameBufferIndex, + int nameSum) { + MeasurementCollector measurementCollector = measurementCollectors[hash & HISTOGRAMS_MASK]; + if (measurementCollector == null) { + measurementCollector = new MeasurementCollector(Arrays.copyOf(nameBuffer, nameBufferIndex)); + measurementCollectors[hash & HISTOGRAMS_MASK] = measurementCollector; + } + else { + // collision unhappy path, try to avoid + while (!nameEquals(measurementCollector.name, measurementCollector.nameSum, nameSum, nameBufferIndex)) { + if (measurementCollector.link == null) { + measurementCollector.link = new MeasurementCollector(Arrays.copyOf(nameBuffer, nameBufferIndex)); + measurementCollector = measurementCollector.link; + break; + } + else { + measurementCollector = measurementCollector.link; + } + } + + } + return measurementCollector; + } + + private static boolean nameEquals(byte[] existingName, int existingNameSum, int incomingNameSum, int nameBufferIndex) { + + if (existingName.length != nameBufferIndex) { + return false; + } + + return incomingNameSum == existingNameSum; + } + + private static int resolveValue(int valueIndex, byte[] valueBuffer, byte decimalValue, boolean isNegative) { + int value; + if (valueIndex == 1) { + value = ((valueBuffer[0] - OFFSET) * 10) + (decimalValue - OFFSET); + } + else // it's 2 digits + { + value = ((valueBuffer[0] - OFFSET) * 100) + ((valueBuffer[1] - OFFSET) * 10) + (decimalValue - OFFSET); + } + + if (isNegative) { + value = Math.negateExact(value); + } + return value; + } + + private static int readNumber(boolean isNegative, byte[] valueBuffer, int valueIndex, byte aChar, MappedByteBuffer r) { + if (!isNegative) { + valueBuffer[valueIndex++] = aChar; + } + + // maybe one or two more + while ((aChar = r.get()) != DECIMAL_POINT) { + valueBuffer[valueIndex++] = aChar; + } + return valueIndex; + } + + private static class MeasurementCollector { + private final byte[] name; + private final int nameSum; + public MeasurementCollector link; + private long sum; + private int count; + private int min = Integer.MAX_VALUE; + private int max = Integer.MIN_VALUE; + + public MeasurementCollector(byte[] name) { + + this.name = name; + int nameSum = 0; + for (int i = 0; i < name.length; i++) { + nameSum += name[i]; + } + this.nameSum = nameSum; + } + + public void feed(int value) { + sum += value; + count++; + min = Math.min(value, min); + max = Math.max(value, max); + } + + public void merge(MeasurementCollector measurementCollector) { + this.sum += measurementCollector.sum; + this.count += measurementCollector.count; + this.min = Math.min(measurementCollector.min, this.min); + this.max = Math.max(measurementCollector.max, this.max); + } + } + + private static class MeasurementResult { + private final String name; + private final double mean; + private final BigDecimal max; + private final BigDecimal min; + + public MeasurementResult(String name, double mean, BigDecimal max, BigDecimal min) { + + this.name = name; + this.mean = mean; + this.max = max; + this.min = min; + } + + @Override + public String toString() { + // Abha=-24.9/18.0/61.7 + return name + "=" + min + "/" + mean + "/" + max; + } + + public static MeasurementResult from(MeasurementCollector mc) { + double mean = Math.round((double) mc.sum / (double) mc.count) / 10d; + BigDecimal max = BigDecimal.valueOf(mc.max).divide(BigDecimal.TEN, 1, RoundingMode.HALF_UP); + BigDecimal min = BigDecimal.valueOf(mc.min).divide(BigDecimal.TEN, 1, RoundingMode.HALF_UP); + return new MeasurementResult(new String(mc.name), mean, max, min); + } + } +}