1brc/src/main/java/dev/morling/onebrc/CalculateAverage_isolgpus.java
Jamie Stansfield 4614b81eb6
isolgpus: submission 1
* isolgpus: submission 1

* isolgpus: fix min value bug (breaks if a negative temperature never appears)

* isolgpus: remove unused collector

* isolgpus: fix split on chunk bug

* isolgpus: change name equality algo to a cheaper check.

* isolgpus: fix chunking state to cope with last byte of last chunk

* isolgpus: hash as we go, instead of at the end

* isolgpus: adjust thread count to core count

* isolgpus: change cores to 8 statically

---------

Co-authored-by: Jamie Stansfield <jalstansfield@gmail.com>
2024-01-05 23:10:43 +01:00

294 lines
11 KiB
Java

/*
* Copyright 2023 The original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.morling.onebrc;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.BufferUnderflowException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.*;
import java.util.stream.Collectors;
/**
 * Reads {@code ./measurements.txt}, whose lines have the form {@code <station>;<temperature>},
 * and prints the min/mean/max temperature of every station, sorted by station name.
 *
 * <p>The file is split into byte ranges ("chunks") that are memory-mapped and parsed on a fixed
 * thread pool. A chunk owns every line that <em>starts</em> inside its range: it finishes a line
 * that runs past its end, and skips the tail of a line that began in the previous chunk.
 * Temperatures are parsed and aggregated as integers in tenths of a degree.
 */
public class CalculateAverage_isolgpus {
    public static final int HISTOGRAMS_LENGTH = 1024 * 32;
    public static final int HISTOGRAMS_MASK = HISTOGRAMS_LENGTH - 1;
    public static final int THREAD_COUNT = 8;
    private static final String FILE = "./measurements.txt";
    public static final byte SEPERATOR = 59; // ';'
    public static final byte OFFSET = 48; // '0'
    public static final byte NEGATIVE = 45; // '-'
    public static final byte DECIMAL_POINT = 46; // '.'
    // Each chunk is mapped together with READ_AHEAD extra bytes, and a single mapping may not exceed
    // Integer.MAX_VALUE bytes, so leave comfortably more head room than READ_AHEAD.
    public static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 1024;
    public static final byte NEW_LINE = 10; // '\n'
    // Bytes mapped past a chunk's end so the last line starting inside the chunk can be completed.
    private static final int READ_AHEAD = 200;

    /**
     * Splits the input file into chunks, parses them in parallel, merges the per-chunk results and
     * prints them as {@code {name=min/mean/max, ...}}.
     */
    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
        File file = Paths.get(FILE).toFile();
        long length = file.length();
        long chunksCount = Math.max(THREAD_COUNT, (long) Math.ceil(length / (double) MAX_CHUNK_SIZE));
        // Round up so the chunks cover every byte of the file; the last chunk may be shorter.
        long chunkSize = (length + chunksCount - 1) / chunksCount;

        ExecutorService executorService = Executors.newFixedThreadPool(THREAD_COUNT);
        try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r");
                FileChannel channel = randomAccessFile.getChannel()) {
            List<Future<MeasurementCollector[]>> futures = new ArrayList<>();
            // chunkSize is 0 only when the file is empty, in which case no chunk is submitted.
            for (long start = 0; start < length; start += chunkSize) {
                long chunkStart = start;
                long chunkLength = Math.min(chunkSize, length - start);
                futures.add(executorService.submit(() -> handleChunk(channel, chunkStart, chunkLength, length)));
            }
            List<MeasurementCollector[]> measurementCollectors = new ArrayList<>(futures.size());
            for (Future<MeasurementCollector[]> result : futures) {
                measurementCollectors.add(result.get());
            }
            Map<String, MeasurementCollector> measurementCollectorsByCity = mergeMeasurements(measurementCollectors);
            List<MeasurementResult> results = measurementCollectorsByCity.values().stream().map(MeasurementResult::from).toList();
            System.out.println("{" + results.stream().map(MeasurementResult::toString).collect(Collectors.joining(", ")) + "}");
        }
        finally {
            executorService.shutdown();
        }
    }

    /**
     * Merges the per-chunk hash tables into one map keyed by station name (decoded as UTF-8),
     * ordered by natural String order.
     */
    private static Map<String, MeasurementCollector> mergeMeasurements(List<MeasurementCollector[]> resultsFromAllChunk) {
        Map<String, MeasurementCollector> mergedResults = new TreeMap<>(Comparator.naturalOrder());
        for (MeasurementCollector[] resultFromSpecificChunk : resultsFromAllChunk) {
            for (MeasurementCollector bucket : resultFromSpecificChunk) {
                // walk the collision chain hanging off each bucket
                for (MeasurementCollector fromChunk = bucket; fromChunk != null; fromChunk = fromChunk.link) {
                    byte[] name = fromChunk.name;
                    mergedResults
                            .computeIfAbsent(new String(name, StandardCharsets.UTF_8), key -> new MeasurementCollector(name))
                            .merge(fromChunk);
                }
            }
        }
        return mergedResults;
    }

    /**
     * Parses every line that starts inside {@code [estimatedStart, estimatedStart + lengthOfChunk)}.
     *
     * @param channel         channel of the input file
     * @param estimatedStart  first file offset owned by this chunk
     * @param lengthOfChunk   number of bytes owned by this chunk
     * @param maxLengthOfFile total file length, used to clamp the mapping
     * @return this chunk's hash table: HISTOGRAMS_LENGTH buckets, collisions chained via {@code link}
     * @throws IllegalStateException if the mapped region ends in the middle of a line
     */
    private static MeasurementCollector[] handleChunk(FileChannel channel, long estimatedStart, long lengthOfChunk, long maxLengthOfFile) throws IOException {
        // Map one byte early so we can tell whether estimatedStart already begins a line,
        // and READ_AHEAD bytes late so the last line starting inside the chunk can be finished.
        long seekStart = Math.max(estimatedStart - 1, 0);
        long mappedLength = Math.min(lengthOfChunk + READ_AHEAD, maxLengthOfFile - seekStart);
        MappedByteBuffer r = channel.map(FileChannel.MapMode.READ_ONLY, seekStart, mappedLength);
        // Last buffer position at which a line may start and still belong to this chunk;
        // a line starting one byte later belongs to the next chunk.
        int lastLineStart = (int) (estimatedStart + lengthOfChunk - 1 - seekStart);

        byte[] nameBuffer = new byte[100];
        byte[] valueBuffer = new byte[3];
        MeasurementCollector[] measurementCollectors = new MeasurementCollector[HISTOGRAMS_LENGTH];

        if (estimatedStart != 0) {
            // Skip to just past the next newline. If the byte before estimatedStart is a newline this
            // consumes only that byte; otherwise it skips the tail of a line owned by the previous chunk.
            while (r.hasRemaining() && r.get() != NEW_LINE) {
                // skipping
            }
        }

        try {
            while (r.position() <= lastLineStart) {
                int nameLength = 0;
                int nameSum = 0;
                int hash = 0;
                byte aChar;
                while ((aChar = r.get()) != SEPERATOR) {
                    nameBuffer[nameLength++] = aChar;
                    nameSum += aChar;
                    hash = 31 * hash + aChar;
                }

                aChar = r.get();
                boolean isNegative = aChar == NEGATIVE;
                int valueIndex = readNumber(isNegative, valueBuffer, 0, aChar, r);
                byte decimalValue = r.get();
                int value = resolveValue(valueIndex, valueBuffer, decimalValue, isNegative);
                if (r.hasRemaining()) {
                    r.get(); // newline; the final line of the file may not have one
                }

                resolveMeasurementCollector(measurementCollectors, hash, nameBuffer, nameLength, nameSum).feed(value);
            }
        }
        catch (BufferUnderflowException e) {
            throw new IllegalStateException("Truncated line near file offset " + (seekStart + r.position()), e);
        }
        return measurementCollectors;
    }

    /**
     * Returns this chunk's collector for the name in {@code nameBuffer[0, nameBufferIndex)}. If the
     * name has not been seen yet, a collector is created: in the empty bucket, or appended to the
     * bucket's collision chain.
     */
    private static MeasurementCollector resolveMeasurementCollector(MeasurementCollector[] measurementCollectors, int hash, byte[] nameBuffer, int nameBufferIndex,
                                                                    int nameSum) {
        int bucket = hash & HISTOGRAMS_MASK;
        MeasurementCollector measurementCollector = measurementCollectors[bucket];
        if (measurementCollector == null) {
            measurementCollector = new MeasurementCollector(Arrays.copyOf(nameBuffer, nameBufferIndex));
            measurementCollectors[bucket] = measurementCollector;
            return measurementCollector;
        }
        // collision unhappy path, try to avoid: walk the chain until the name matches or the chain ends
        while (!nameEquals(measurementCollector.name, measurementCollector.nameSum, nameBuffer, nameBufferIndex, nameSum)) {
            if (measurementCollector.link == null) {
                measurementCollector.link = new MeasurementCollector(Arrays.copyOf(nameBuffer, nameBufferIndex));
                return measurementCollector.link;
            }
            measurementCollector = measurementCollector.link;
        }
        return measurementCollector;
    }

    /**
     * Compares a stored name with {@code incomingName[0, incomingLength)}. The length and byte-sum
     * checks are cheap early rejects. The byte-wise comparison makes the result exact, because
     * distinct names can share a length and a byte sum (e.g. anagrams).
     */
    private static boolean nameEquals(byte[] existingName, int existingNameSum, byte[] incomingName, int incomingLength, int incomingNameSum) {
        return existingName.length == incomingLength
                && existingNameSum == incomingNameSum
                && Arrays.equals(existingName, 0, incomingLength, incomingName, 0, incomingLength);
    }

    /**
     * Combines the integer-part ASCII digits in {@code valueBuffer} (one or two of them) with the
     * fractional digit into a value in tenths, e.g. digits "12" and '3' give 123.
     * The result is negated when {@code isNegative}.
     */
    private static int resolveValue(int valueIndex, byte[] valueBuffer, byte decimalValue, boolean isNegative) {
        int value;
        if (valueIndex == 1) {
            value = ((valueBuffer[0] - OFFSET) * 10) + (decimalValue - OFFSET);
        }
        else // it's 2 digits
        {
            value = ((valueBuffer[0] - OFFSET) * 100) + ((valueBuffer[1] - OFFSET) * 10) + (decimalValue - OFFSET);
        }
        if (isNegative) {
            value = Math.negateExact(value);
        }
        return value;
    }

    /**
     * Copies the integer-part digits into {@code valueBuffer} starting at {@code valueIndex}, reading
     * from {@code r} up to and including the decimal point. {@code aChar} is the byte already read
     * after the separator; it is stored unless it is the minus sign.
     *
     * @return the index after the last stored digit
     */
    private static int readNumber(boolean isNegative, byte[] valueBuffer, int valueIndex, byte aChar, MappedByteBuffer r) {
        if (!isNegative) {
            valueBuffer[valueIndex++] = aChar;
        }
        // maybe one or two more
        while ((aChar = r.get()) != DECIMAL_POINT) {
            valueBuffer[valueIndex++] = aChar;
        }
        return valueIndex;
    }

    /**
     * Running min/max/sum/count (in tenths) for one station. {@code link} chains collectors whose
     * names fall into the same hash bucket.
     */
    private static class MeasurementCollector {
        private final byte[] name;
        private final int nameSum;
        public MeasurementCollector link;
        private long sum;
        private long count;
        private int min = Integer.MAX_VALUE;
        private int max = Integer.MIN_VALUE;

        public MeasurementCollector(byte[] name) {
            this.name = name;
            int nameSum = 0;
            for (byte b : name) {
                nameSum += b;
            }
            this.nameSum = nameSum;
        }

        public void feed(int value) {
            sum += value;
            count++;
            min = Math.min(value, min);
            max = Math.max(value, max);
        }

        public void merge(MeasurementCollector measurementCollector) {
            this.sum += measurementCollector.sum;
            this.count += measurementCollector.count;
            this.min = Math.min(measurementCollector.min, this.min);
            this.max = Math.max(measurementCollector.max, this.max);
        }
    }

    /** Final per-station values, rendered as {@code name=min/mean/max}. */
    private static class MeasurementResult {
        private final String name;
        private final double mean;
        private final BigDecimal max;
        private final BigDecimal min;

        public MeasurementResult(String name, double mean, BigDecimal max, BigDecimal min) {
            this.name = name;
            this.mean = mean;
            this.max = max;
            this.min = min;
        }

        @Override
        public String toString() {
            // Abha=-24.9/18.0/61.7
            return name + "=" + min + "/" + mean + "/" + max;
        }

        /**
         * Converts a collector's tenths into degrees. The mean is rounded to one decimal with
         * {@link Math#round}; min/max are exact tenths with scale 1.
         */
        public static MeasurementResult from(MeasurementCollector mc) {
            double mean = Math.round((double) mc.sum / (double) mc.count) / 10d;
            BigDecimal max = BigDecimal.valueOf(mc.max).divide(BigDecimal.TEN, 1, RoundingMode.HALF_UP);
            BigDecimal min = BigDecimal.valueOf(mc.min).divide(BigDecimal.TEN, 1, RoundingMode.HALF_UP);
            return new MeasurementResult(new String(mc.name, StandardCharsets.UTF_8), mean, max, min);
        }
    }
}