1brc/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java

210 lines
8.1 KiB
Java
Raw Normal View History

/*
* Copyright 2023 The original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.morling.onebrc;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorOperators;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.charset.StandardCharsets;
import java.util.*;
/**
* Approach:
* - Split the file into segments of roughly 10 MB and read each segment into a byte array via RandomAccessFile
* - Partition data, compute aggregates for partitions in parallel, and finally combine results from all partitions
* - Apply SIMD instructions for computing min/max/sum aggregates
* - Use Shorts for storing aggregates of partitions, so we maximize the SIMD parallelism
*/
public class CalculateAverage_flippingbits {

    /** Input file, resolved relative to the working directory. */
    private static final String FILE = "./measurements.txt";

    /** Nominal size of a segment; segments are extended to the next line break. */
    private static final long CHUNK_SIZE = 10 * 1024 * 1024; // 10 MB

    /** Number of short lanes in the widest vector the platform supports. */
    private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length();

    /**
     * Aggregates all segments of the input file in parallel, merges the per-segment results
     * and prints them sorted by station name.
     *
     * @param args ignored
     * @throws IOException if the segment boundaries cannot be determined
     */
    public static void main(String[] args) throws IOException {
        var result = Arrays.stream(getSegments())
                .parallel()
                .map(segment -> {
                    try {
                        return processSegment(segment[0], segment[1]);
                    }
                    catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                })
                .reduce(CalculateAverage_flippingbits::mergeAggregates)
                // getSegments() always yields at least one segment, so the Optional is never empty
                .map(TreeMap::new)
                .get();

        System.out.println(result);
    }

    /**
     * Folds the aggregates of {@code secondMap} into {@code firstMap}.
     *
     * @param firstMap  map that receives the merged aggregates; it is modified in place
     * @param secondMap map whose entries are merged into {@code firstMap}
     * @return {@code firstMap}
     */
    private static Map<String, PartitionAggregate> mergeAggregates(Map<String, PartitionAggregate> firstMap,
                                                                   Map<String, PartitionAggregate> secondMap) {
        for (var entry : secondMap.entrySet()) {
            PartitionAggregate firstAggregate = firstMap.get(entry.getKey());
            if (firstAggregate == null) {
                firstMap.put(entry.getKey(), entry.getValue());
            }
            else {
                firstAggregate.mergeWith(entry.getValue());
            }
        }
        return firstMap;
    }

    /**
     * Splits the input file into line-aligned segments.
     *
     * @return an array of {@code [start, end)} byte offsets; the last entry covers whatever
     *         remains after the full-size segments and may be empty
     * @throws IOException if the file cannot be opened or read
     */
    private static long[][] getSegments() throws IOException {
        try (var file = new RandomAccessFile(FILE, "r")) {
            var fileSize = file.length();
            // Split file into segments, so each segment fits into a single byte array
            var numSegments = (int) (fileSize / CHUNK_SIZE);
            var boundaries = new long[numSegments + 1][2];
            var endPointer = 0L;

            for (var i = 0; i < numSegments; i++) {
                // Start of segment
                boundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize);

                // Seek end of segment, limited by the end of the file
                file.seek(Math.min(boundaries[i][0] + CHUNK_SIZE - 1, fileSize));

                // Extend segment until end of line or end of file. read() returns -1 at EOF,
                // so EOF must terminate the loop explicitly or it would spin forever.
                int nextByte;
                do {
                    nextByte = file.read();
                } while (nextByte != '\n' && nextByte != -1);

                // End of segment
                endPointer = file.getFilePointer();
                boundaries[i][1] = endPointer;
            }

            boundaries[numSegments][0] = Math.max(endPointer, numSegments * CHUNK_SIZE);
            boundaries[numSegments][1] = fileSize;

            return boundaries;
        }
    }

    /**
     * Parses all {@code station;measurement} lines of one segment and aggregates them per station.
     * Measurements are expected to have exactly one fractional digit and are handled as
     * integers scaled by ten.
     *
     * @param startOfSegment offset of the first byte of the segment
     * @param endOfSegment   offset one past the last byte of the segment
     * @return the aggregates of this segment, keyed by station name
     * @throws IOException if the segment cannot be read completely
     */
    private static Map<String, PartitionAggregate> processSegment(long startOfSegment, long endOfSegment)
            throws IOException {
        Map<String, PartitionAggregate> stationAggregates = new HashMap<>(50_000);
        var byteChunk = new byte[(int) (endOfSegment - startOfSegment)];

        try (var file = new RandomAccessFile(FILE, "r")) {
            file.seek(startOfSegment);
            // read() may return after filling only part of the buffer; readFully() does not
            file.readFully(byteChunk);
        }

        var i = 0;
        while (i < byteChunk.length) {
            // Read station name; it has at least one byte, so the first byte is never a separator
            var nameStart = i;
            i++;
            while (byteChunk[i] != ';') {
                i++;
            }
            var station = new String(byteChunk, nameStart, i - nameStart, StandardCharsets.UTF_8);
            i++;

            // Read measurement: optional sign, integer digits, '.', one fractional digit
            var isNegative = byteChunk[i] == '-';
            if (isNegative) {
                i++;
            }
            var measurement = 0;
            while (byteChunk[i] != '.') {
                measurement = measurement * 10 + (byteChunk[i] - '0');
                i++;
            }
            measurement = measurement * 10 + (byteChunk[i + 1] - '0');
            if (isNegative) {
                measurement = -measurement;
            }

            // Update aggregate
            var aggregate = stationAggregates.computeIfAbsent(station, x -> new PartitionAggregate());
            aggregate.addMeasurementAndComputeAggregate((short) measurement);

            // Skip '.', the fractional digit and the line break
            i += 3;
        }

        // Fold measurements that are still buffered and were not processed by SIMD yet
        stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements);

        return stationAggregates;
    }

    /**
     * Min/max/sum/count of the measurements of one station. Measurements are buffered and
     * aggregated with vector instructions whenever two full vectors have been collected.
     */
    private static class PartitionAggregate {
        /** Buffer holding the measurements of two vectors. */
        final short[] doubleLane = new short[SIMD_LANE_LENGTH * 2];
        // Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition
        int count = 0;
        long sum = 0;
        short min = Short.MAX_VALUE;
        short max = Short.MIN_VALUE;

        /**
         * Buffers a measurement and aggregates the buffer once it is full.
         *
         * @param measurement measurement scaled by ten
         */
        public void addMeasurementAndComputeAggregate(short measurement) {
            // Add measurement to buffer, which is later processed by SIMD instructions
            doubleLane[count % doubleLane.length] = measurement;
            count++;

            // Once buffer is full, use SIMD instructions to calculate aggregates
            if (count % doubleLane.length == 0) {
                var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, 0);
                var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, SIMD_LANE_LENGTH);

                var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN);
                min = (short) Math.min(min, simdMin);

                var simdMax = firstVector.max(secondVector).reduceLanes(VectorOperators.MAX);
                max = (short) Math.max(max, simdMax);

                sum += sumLanes(firstVector) + sumLanes(secondVector);
            }
        }

        /**
         * Sums all lanes of a short vector without 16-bit wrap-around. Adding in short lanes
         * overflows as soon as the lane total leaves the short range, which happens on wide
         * vectors, so the lanes are widened to int first.
         *
         * @param vector the vector to sum
         * @return the exact sum of all lanes
         */
        private static long sumLanes(ShortVector vector) {
            // A same-shape short-to-int conversion yields half of the lanes per part
            var lowerHalf = vector.convert(VectorOperators.S2I, 0).reinterpretAsInts();
            var upperHalf = vector.convert(VectorOperators.S2I, 1).reinterpretAsInts();
            return lowerHalf.add(upperHalf).reduceLanes(VectorOperators.ADD);
        }

        /**
         * Aggregates the buffered measurements that did not fill the buffer completely.
         * Must be called exactly once, after the last measurement has been added.
         */
        public void aggregateRemainingMeasurements() {
            var remaining = count % doubleLane.length;
            for (var i = 0; i < remaining; i++) {
                var measurement = doubleLane[i];
                min = (short) Math.min(min, measurement);
                max = (short) Math.max(max, measurement);
                sum += measurement;
            }
        }

        /**
         * Merges another aggregate of the same station into this one.
         *
         * @param otherAggregate aggregate to merge; it is not modified
         */
        public void mergeWith(PartitionAggregate otherAggregate) {
            min = (short) Math.min(min, otherAggregate.min);
            max = (short) Math.max(max, otherAggregate.max);
            count = count + otherAggregate.count;
            sum = sum + otherAggregate.sum;
        }

        /** Formats the aggregate as {@code min/mean/max} with one fractional digit each. */
        @Override
        public String toString() {
            return String.format(
                    Locale.US,
                    "%.1f/%.1f/%.1f",
                    (min / 10.0),
                    ((sum / 10.0) / count),
                    (max / 10.0));
        }
    }
}