diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java b/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java index 33ca88c..2510d85 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java @@ -20,6 +20,7 @@ import jdk.incubator.vector.VectorOperators; import java.io.IOException; import java.io.RandomAccessFile; +import java.nio.charset.StandardCharsets; import java.util.*; /** @@ -33,103 +34,118 @@ public class CalculateAverage_flippingbits { private static final String FILE = "./measurements.txt"; - private static final long CHUNK_SIZE = 100 * 1024 * 1024; // 100 MB + private static final long CHUNK_SIZE = 10 * 1024 * 1024; // 10 MB private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length(); - public static void main(String[] args) throws IOException { - try (var file = new RandomAccessFile(FILE, "r")) { - // Calculate chunk boundaries - long[][] chunkBoundaries = getChunkBoundaries(file); - // Process chunks - var result = Arrays.asList(chunkBoundaries).stream() - .map(chunk -> { - try { - return processChunk(chunk[0], chunk[1]); - } - catch (IOException e) { - throw new RuntimeException(e); - } - }) - .parallel() - .reduce((firstMap, secondMap) -> { - for (var entry : secondMap.entrySet()) { - PartitionAggregate firstAggregate = firstMap.get(entry.getKey()); - if (firstAggregate == null) { - firstMap.put(entry.getKey(), entry.getValue()); - } - else { - firstAggregate.mergeWith(entry.getValue()); - } - } - return firstMap; - }) - .map(hashMap -> new TreeMap(hashMap)).get(); + private static final int MAX_STATION_NAME_LENGTH = 100; - System.out.println(result); - } + public static void main(String[] args) throws IOException { + var result = Arrays.asList(getSegments()).stream() + .map(segment -> { + try { + return processSegment(segment[0], segment[1]); + } + catch (IOException e) { + throw new RuntimeException(e); + } + }) + .parallel() + .reduce((firstMap, secondMap) -> { + for (var entry : secondMap.entrySet()) { + PartitionAggregate firstAggregate = firstMap.get(entry.getKey()); + if (firstAggregate == null) { + firstMap.put(entry.getKey(), entry.getValue()); + } + else { + firstAggregate.mergeWith(entry.getValue()); + } + } + return firstMap; + }) + .map(TreeMap::new).get(); + + System.out.println(result); } - private static long[][] getChunkBoundaries(RandomAccessFile file) throws IOException { - var fileSize = file.length(); - // Split file into chunks, so we can work around the size limitation of channels - var chunks = (int) (fileSize / CHUNK_SIZE); + private static long[][] getSegments() throws IOException { + try (var file = new RandomAccessFile(FILE, "r")) { + var fileSize = file.length(); + // Split file into segments, so we can work around the size limitation of channels + var numSegments = (int) (fileSize / CHUNK_SIZE); - long[][] chunkBoundaries = new long[chunks + 1][2]; - var endPointer = 0L; + var boundaries = new long[numSegments + 1][2]; + var endPointer = 0L; - for (var i = 0; i <= chunks; i++) { - // Start of chunk - chunkBoundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize); + for (var i = 0; i < numSegments; i++) { + // Start of segment + boundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize); - // Seek end of chunk, limited by the end of the file - file.seek(Math.min(chunkBoundaries[i][0] + CHUNK_SIZE - 1, fileSize)); + // Seek end of segment, limited by the end of the file + file.seek(Math.min(boundaries[i][0] + CHUNK_SIZE - 1, fileSize)); - // Extend chunk until end of line or file - while (true) { - var character = file.read(); - if (character == '\n' || character == -1) { - break; + // Extend segment until end of line or file + while (file.read() != '\n') { } + + // End of segment + endPointer = file.getFilePointer(); + boundaries[i][1] = endPointer; } - // End of chunk - endPointer = file.getFilePointer(); - chunkBoundaries[i][1] = endPointer; - } + boundaries[numSegments][0] = Math.max(endPointer, numSegments * CHUNK_SIZE); + boundaries[numSegments][1] = fileSize; - return chunkBoundaries; + return boundaries; + } } - private static Map processChunk(long startOfChunk, long endOfChunk) + private static Map processSegment(long startOfSegment, long endOfSegment) throws IOException { - Map stationAggregates = new HashMap<>(10_000); - byte[] byteChunk = new byte[(int) (endOfChunk - startOfChunk)]; + Map stationAggregates = new HashMap<>(50_000); + var byteChunk = new byte[(int) (endOfSegment - startOfSegment)]; + var stationBuffer = new byte[MAX_STATION_NAME_LENGTH]; try (var file = new RandomAccessFile(FILE, "r")) { - file.seek(startOfChunk); + file.seek(startOfSegment); file.read(byteChunk); var i = 0; while (i < byteChunk.length) { - final var startPosStation = i; - - // read station name + // Station name has at least one byte + stationBuffer[0] = byteChunk[i]; + i++; + // Read station name + var j = 1; while (byteChunk[i] != ';') { + stationBuffer[j] = byteChunk[i]; + j++; i++; } - var station = new String(Arrays.copyOfRange(byteChunk, startPosStation, i)); + var station = new String(stationBuffer, 0, j, StandardCharsets.UTF_8); i++; - // read measurement - final var startPosMeasurement = i; - while (byteChunk[i] != '\n') { + // Read measurement + var isNegative = byteChunk[i] == '-'; + var measurement = 0; + if (isNegative) { i++; + while (byteChunk[i] != '.') { + measurement = measurement * 10 + byteChunk[i] - '0'; + i++; + } + measurement = (measurement * 10 + byteChunk[i + 1] - '0') * -1; + } + else { + while (byteChunk[i] != '.') { + measurement = measurement * 10 + byteChunk[i] - '0'; + i++; + } + measurement = measurement * 10 + byteChunk[i + 1] - '0'; } - var measurement = Arrays.copyOfRange(byteChunk, startPosMeasurement, i); - var aggregate = stationAggregates.getOrDefault(station, new PartitionAggregate()); - aggregate.addMeasurementAndComputeAggregate(measurement); - stationAggregates.put(station, aggregate); - i++; + // Update aggregate + var aggregate = stationAggregates.computeIfAbsent(station, x -> new PartitionAggregate()); + aggregate.addMeasurementAndComputeAggregate((short) measurement); + i += 3; } stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements); } @@ -138,40 +154,22 @@ public class CalculateAverage_flippingbits { } private static class PartitionAggregate { - final short[] lane = new short[SIMD_LANE_LENGTH * 2]; + final short[] doubleLane = new short[SIMD_LANE_LENGTH * 2]; // Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition int count = 0; long sum = 0; short min = Short.MAX_VALUE; short max = Short.MIN_VALUE; - public void addMeasurementAndComputeAggregate(byte[] measurementBytes) { - // Parse measurement and exploit that we know the format of the floating-point values - var measurement = measurementBytes[measurementBytes.length - 1] - '0'; - var digits = 1; - for (var i = measurementBytes.length - 3; i > 0; i--) { - var num = measurementBytes[i] - '0'; - measurement = measurement + (num * (int) Math.pow(10, digits)); - digits++; - } - - // Check if measurement is negative - if (measurementBytes[0] == '-') { - measurement = measurement * -1; - } - else { - var num = measurementBytes[0] - '0'; - measurement = measurement + (num * (int) Math.pow(10, digits)); - } - + public void addMeasurementAndComputeAggregate(short measurement) { // Add measurement to buffer, which is later processed by SIMD instructions - lane[count % lane.length] = (short) measurement; + doubleLane[count % doubleLane.length] = measurement; count++; // Once lane is full, use SIMD instructions to calculate aggregates - if (count % lane.length == 0) { - var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, lane, 0); - var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, lane, SIMD_LANE_LENGTH); + if (count % doubleLane.length == 0) { + var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, 0); + var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, SIMD_LANE_LENGTH); var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN); min = (short) Math.min(min, simdMin); @@ -184,8 +182,8 @@ public class CalculateAverage_flippingbits { } public void aggregateRemainingMeasurements() { - for (var i = 0; i < count % lane.length; i++) { - var measurement = lane[i]; + for (var i = 0; i < count % doubleLane.length; i++) { + var measurement = doubleLane[i]; min = (short) Math.min(min, measurement); max = (short) Math.max(max, measurement); sum += measurement; @@ -204,7 +202,7 @@ public class CalculateAverage_flippingbits { Locale.US, "%.1f/%.1f/%.1f", (min / 10.0), - (sum / 10.0) / count, + ((sum / 10.0) / count), (max / 10.0)); } }