Second submission by flippingbits - 50% performance improvement

* feat(flippingbits): Improve parsing of measurement and few cleanups

* feat(flippingbits): Reduce chunk size to 10MB

* feat(flippingbits): Improve parsing of station names

* chore(flippingbits): Remove obsolete import

* chore(flippingbits): Few cleanups
This commit is contained in:
Stefan Sprenger 2024-01-10 20:36:22 +01:00 committed by GitHub
parent 97b1f014ad
commit a8a3876416
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,6 +20,7 @@ import jdk.incubator.vector.VectorOperators;
import java.io.IOException; import java.io.IOException;
import java.io.RandomAccessFile; import java.io.RandomAccessFile;
import java.nio.charset.StandardCharsets;
import java.util.*; import java.util.*;
/** /**
@ -33,19 +34,17 @@ public class CalculateAverage_flippingbits {
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
private static final long CHUNK_SIZE = 100 * 1024 * 1024; // 100 MB private static final long CHUNK_SIZE = 10 * 1024 * 1024; // 10 MB
private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length(); private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length();
private static final int MAX_STATION_NAME_LENGTH = 100;
public static void main(String[] args) throws IOException { public static void main(String[] args) throws IOException {
try (var file = new RandomAccessFile(FILE, "r")) { var result = Arrays.asList(getSegments()).stream()
// Calculate chunk boundaries .map(segment -> {
long[][] chunkBoundaries = getChunkBoundaries(file);
// Process chunks
var result = Arrays.asList(chunkBoundaries).stream()
.map(chunk -> {
try { try {
return processChunk(chunk[0], chunk[1]); return processSegment(segment[0], segment[1]);
} }
catch (IOException e) { catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
@ -64,73 +63,90 @@ public class CalculateAverage_flippingbits {
} }
return firstMap; return firstMap;
}) })
.map(hashMap -> new TreeMap(hashMap)).get(); .map(TreeMap::new).get();
System.out.println(result); System.out.println(result);
} }
}
private static long[][] getChunkBoundaries(RandomAccessFile file) throws IOException { private static long[][] getSegments() throws IOException {
try (var file = new RandomAccessFile(FILE, "r")) {
var fileSize = file.length(); var fileSize = file.length();
// Split file into chunks, so we can work around the size limitation of channels // Split file into segments, so we can work around the size limitation of channels
var chunks = (int) (fileSize / CHUNK_SIZE); var numSegments = (int) (fileSize / CHUNK_SIZE);
long[][] chunkBoundaries = new long[chunks + 1][2]; var boundaries = new long[numSegments + 1][2];
var endPointer = 0L; var endPointer = 0L;
for (var i = 0; i <= chunks; i++) { for (var i = 0; i < numSegments; i++) {
// Start of chunk // Start of segment
chunkBoundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize); boundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize);
// Seek end of chunk, limited by the end of the file // Seek end of segment, limited by the end of the file
file.seek(Math.min(chunkBoundaries[i][0] + CHUNK_SIZE - 1, fileSize)); file.seek(Math.min(boundaries[i][0] + CHUNK_SIZE - 1, fileSize));
// Extend chunk until end of line or file // Extend segment until end of line or file
while (true) { while (file.read() != '\n') {
var character = file.read();
if (character == '\n' || character == -1) {
break;
}
} }
// End of chunk // End of segment
endPointer = file.getFilePointer(); endPointer = file.getFilePointer();
chunkBoundaries[i][1] = endPointer; boundaries[i][1] = endPointer;
} }
return chunkBoundaries; boundaries[numSegments][0] = Math.max(endPointer, numSegments * CHUNK_SIZE);
boundaries[numSegments][1] = fileSize;
return boundaries;
}
} }
private static Map<String, PartitionAggregate> processChunk(long startOfChunk, long endOfChunk) private static Map<String, PartitionAggregate> processSegment(long startOfSegment, long endOfSegment)
throws IOException { throws IOException {
Map<String, PartitionAggregate> stationAggregates = new HashMap<>(10_000); Map<String, PartitionAggregate> stationAggregates = new HashMap<>(50_000);
byte[] byteChunk = new byte[(int) (endOfChunk - startOfChunk)]; var byteChunk = new byte[(int) (endOfSegment - startOfSegment)];
var stationBuffer = new byte[MAX_STATION_NAME_LENGTH];
try (var file = new RandomAccessFile(FILE, "r")) { try (var file = new RandomAccessFile(FILE, "r")) {
file.seek(startOfChunk); file.seek(startOfSegment);
file.read(byteChunk); file.read(byteChunk);
var i = 0; var i = 0;
while (i < byteChunk.length) { while (i < byteChunk.length) {
final var startPosStation = i; // Station name has at least one byte
stationBuffer[0] = byteChunk[i];
// read station name i++;
// Read station name
var j = 1;
while (byteChunk[i] != ';') { while (byteChunk[i] != ';') {
stationBuffer[j] = byteChunk[i];
j++;
i++; i++;
} }
var station = new String(Arrays.copyOfRange(byteChunk, startPosStation, i)); var station = new String(stationBuffer, 0, j, StandardCharsets.UTF_8);
i++; i++;
// read measurement // Read measurement
final var startPosMeasurement = i; var isNegative = byteChunk[i] == '-';
while (byteChunk[i] != '\n') { var measurement = 0;
if (isNegative) {
i++;
while (byteChunk[i] != '.') {
measurement = measurement * 10 + byteChunk[i] - '0';
i++; i++;
} }
measurement = (measurement * 10 + byteChunk[i + 1] - '0') * -1;
var measurement = Arrays.copyOfRange(byteChunk, startPosMeasurement, i); }
var aggregate = stationAggregates.getOrDefault(station, new PartitionAggregate()); else {
aggregate.addMeasurementAndComputeAggregate(measurement); while (byteChunk[i] != '.') {
stationAggregates.put(station, aggregate); measurement = measurement * 10 + byteChunk[i] - '0';
i++; i++;
} }
measurement = measurement * 10 + byteChunk[i + 1] - '0';
}
// Update aggregate
var aggregate = stationAggregates.computeIfAbsent(station, x -> new PartitionAggregate());
aggregate.addMeasurementAndComputeAggregate((short) measurement);
i += 3;
}
stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements); stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements);
} }
@ -138,40 +154,22 @@ public class CalculateAverage_flippingbits {
} }
private static class PartitionAggregate { private static class PartitionAggregate {
final short[] lane = new short[SIMD_LANE_LENGTH * 2]; final short[] doubleLane = new short[SIMD_LANE_LENGTH * 2];
// Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition // Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition
int count = 0; int count = 0;
long sum = 0; long sum = 0;
short min = Short.MAX_VALUE; short min = Short.MAX_VALUE;
short max = Short.MIN_VALUE; short max = Short.MIN_VALUE;
public void addMeasurementAndComputeAggregate(byte[] measurementBytes) { public void addMeasurementAndComputeAggregate(short measurement) {
// Parse measurement and exploit that we know the format of the floating-point values
var measurement = measurementBytes[measurementBytes.length - 1] - '0';
var digits = 1;
for (var i = measurementBytes.length - 3; i > 0; i--) {
var num = measurementBytes[i] - '0';
measurement = measurement + (num * (int) Math.pow(10, digits));
digits++;
}
// Check if measurement is negative
if (measurementBytes[0] == '-') {
measurement = measurement * -1;
}
else {
var num = measurementBytes[0] - '0';
measurement = measurement + (num * (int) Math.pow(10, digits));
}
// Add measurement to buffer, which is later processed by SIMD instructions // Add measurement to buffer, which is later processed by SIMD instructions
lane[count % lane.length] = (short) measurement; doubleLane[count % doubleLane.length] = measurement;
count++; count++;
// Once lane is full, use SIMD instructions to calculate aggregates // Once lane is full, use SIMD instructions to calculate aggregates
if (count % lane.length == 0) { if (count % doubleLane.length == 0) {
var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, lane, 0); var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, 0);
var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, lane, SIMD_LANE_LENGTH); var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, SIMD_LANE_LENGTH);
var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN); var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN);
min = (short) Math.min(min, simdMin); min = (short) Math.min(min, simdMin);
@ -184,8 +182,8 @@ public class CalculateAverage_flippingbits {
} }
public void aggregateRemainingMeasurements() { public void aggregateRemainingMeasurements() {
for (var i = 0; i < count % lane.length; i++) { for (var i = 0; i < count % doubleLane.length; i++) {
var measurement = lane[i]; var measurement = doubleLane[i];
min = (short) Math.min(min, measurement); min = (short) Math.min(min, measurement);
max = (short) Math.max(max, measurement); max = (short) Math.max(max, measurement);
sum += measurement; sum += measurement;
@ -204,7 +202,7 @@ public class CalculateAverage_flippingbits {
Locale.US, Locale.US,
"%.1f/%.1f/%.1f", "%.1f/%.1f/%.1f",
(min / 10.0), (min / 10.0),
(sum / 10.0) / count, ((sum / 10.0) / count),
(max / 10.0)); (max / 10.0));
} }
} }