Second submission by flippingbits - 50% performance improvement
* feat(flippingbits): Improve parsing of measurement and few cleanups
* feat(flippingbits): Reduce chunk size to 10MB
* feat(flippingbits): Improve parsing of station names
* chore(flippingbits): Remove obsolete import
* chore(flippingbits): Few cleanups
This commit is contained in:
		| @@ -20,6 +20,7 @@ import jdk.incubator.vector.VectorOperators; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.io.RandomAccessFile; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.util.*; | ||||
|  | ||||
| /** | ||||
| @@ -33,19 +34,17 @@ public class CalculateAverage_flippingbits { | ||||
|  | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|  | ||||
|     private static final long CHUNK_SIZE = 100 * 1024 * 1024; // 100 MB | ||||
|     private static final long CHUNK_SIZE = 10 * 1024 * 1024; // 10 MB | ||||
|  | ||||
|     private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length(); | ||||
|  | ||||
|     private static final int MAX_STATION_NAME_LENGTH = 100; | ||||
|  | ||||
|     public static void main(String[] args) throws IOException { | ||||
|         try (var file = new RandomAccessFile(FILE, "r")) { | ||||
|             // Calculate chunk boundaries | ||||
|             long[][] chunkBoundaries = getChunkBoundaries(file); | ||||
|             // Process chunks | ||||
|             var result = Arrays.asList(chunkBoundaries).stream() | ||||
|                     .map(chunk -> { | ||||
|         var result = Arrays.asList(getSegments()).stream() | ||||
|                 .map(segment -> { | ||||
|                     try { | ||||
|                             return processChunk(chunk[0], chunk[1]); | ||||
|                         return processSegment(segment[0], segment[1]); | ||||
|                     } | ||||
|                     catch (IOException e) { | ||||
|                         throw new RuntimeException(e); | ||||
| @@ -64,73 +63,90 @@ public class CalculateAverage_flippingbits { | ||||
|                     } | ||||
|                     return firstMap; | ||||
|                 }) | ||||
|                     .map(hashMap -> new TreeMap(hashMap)).get(); | ||||
|                 .map(TreeMap::new).get(); | ||||
|  | ||||
|         System.out.println(result); | ||||
|     } | ||||
|     } | ||||
|  | ||||
|     private static long[][] getChunkBoundaries(RandomAccessFile file) throws IOException { | ||||
|     private static long[][] getSegments() throws IOException { | ||||
|         try (var file = new RandomAccessFile(FILE, "r")) { | ||||
|             var fileSize = file.length(); | ||||
|         // Split file into chunks, so we can work around the size limitation of channels | ||||
|         var chunks = (int) (fileSize / CHUNK_SIZE); | ||||
|             // Split file into segments, so we can work around the size limitation of channels | ||||
|             var numSegments = (int) (fileSize / CHUNK_SIZE); | ||||
|  | ||||
|         long[][] chunkBoundaries = new long[chunks + 1][2]; | ||||
|             var boundaries = new long[numSegments + 1][2]; | ||||
|             var endPointer = 0L; | ||||
|  | ||||
|         for (var i = 0; i <= chunks; i++) { | ||||
|             // Start of chunk | ||||
|             chunkBoundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize); | ||||
|             for (var i = 0; i < numSegments; i++) { | ||||
|                 // Start of segment | ||||
|                 boundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize); | ||||
|  | ||||
|             // Seek end of chunk, limited by the end of the file | ||||
|             file.seek(Math.min(chunkBoundaries[i][0] + CHUNK_SIZE - 1, fileSize)); | ||||
|                 // Seek end of segment, limited by the end of the file | ||||
|                 file.seek(Math.min(boundaries[i][0] + CHUNK_SIZE - 1, fileSize)); | ||||
|  | ||||
|             // Extend chunk until end of line or file | ||||
|             while (true) { | ||||
|                 var character = file.read(); | ||||
|                 if (character == '\n' || character == -1) { | ||||
|                     break; | ||||
|                 } | ||||
|                 // Extend segment until end of line or file | ||||
|                 while (file.read() != '\n') { | ||||
|                 } | ||||
|  | ||||
|             // End of chunk | ||||
|                 // End of segment | ||||
|                 endPointer = file.getFilePointer(); | ||||
|             chunkBoundaries[i][1] = endPointer; | ||||
|                 boundaries[i][1] = endPointer; | ||||
|             } | ||||
|  | ||||
|         return chunkBoundaries; | ||||
|             boundaries[numSegments][0] = Math.max(endPointer, numSegments * CHUNK_SIZE); | ||||
|             boundaries[numSegments][1] = fileSize; | ||||
|  | ||||
|             return boundaries; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static Map<String, PartitionAggregate> processChunk(long startOfChunk, long endOfChunk) | ||||
|     private static Map<String, PartitionAggregate> processSegment(long startOfSegment, long endOfSegment) | ||||
|             throws IOException { | ||||
|         Map<String, PartitionAggregate> stationAggregates = new HashMap<>(10_000); | ||||
|         byte[] byteChunk = new byte[(int) (endOfChunk - startOfChunk)]; | ||||
|         Map<String, PartitionAggregate> stationAggregates = new HashMap<>(50_000); | ||||
|         var byteChunk = new byte[(int) (endOfSegment - startOfSegment)]; | ||||
|         var stationBuffer = new byte[MAX_STATION_NAME_LENGTH]; | ||||
|         try (var file = new RandomAccessFile(FILE, "r")) { | ||||
|             file.seek(startOfChunk); | ||||
|             file.seek(startOfSegment); | ||||
|             file.read(byteChunk); | ||||
|             var i = 0; | ||||
|             while (i < byteChunk.length) { | ||||
|                 final var startPosStation = i; | ||||
|  | ||||
|                 // read station name | ||||
|                 // Station name has at least one byte | ||||
|                 stationBuffer[0] = byteChunk[i]; | ||||
|                 i++; | ||||
|                 // Read station name | ||||
|                 var j = 1; | ||||
|                 while (byteChunk[i] != ';') { | ||||
|                     stationBuffer[j] = byteChunk[i]; | ||||
|                     j++; | ||||
|                     i++; | ||||
|                 } | ||||
|                 var station = new String(Arrays.copyOfRange(byteChunk, startPosStation, i)); | ||||
|                 var station = new String(stationBuffer, 0, j, StandardCharsets.UTF_8); | ||||
|                 i++; | ||||
|  | ||||
|                 // read measurement | ||||
|                 final var startPosMeasurement = i; | ||||
|                 while (byteChunk[i] != '\n') { | ||||
|                 // Read measurement | ||||
|                 var isNegative = byteChunk[i] == '-'; | ||||
|                 var measurement = 0; | ||||
|                 if (isNegative) { | ||||
|                     i++; | ||||
|                     while (byteChunk[i] != '.') { | ||||
|                         measurement = measurement * 10 + byteChunk[i] - '0'; | ||||
|                         i++; | ||||
|                     } | ||||
|  | ||||
|                 var measurement = Arrays.copyOfRange(byteChunk, startPosMeasurement, i); | ||||
|                 var aggregate = stationAggregates.getOrDefault(station, new PartitionAggregate()); | ||||
|                 aggregate.addMeasurementAndComputeAggregate(measurement); | ||||
|                 stationAggregates.put(station, aggregate); | ||||
|                     measurement = (measurement * 10 + byteChunk[i + 1] - '0') * -1; | ||||
|                 } | ||||
|                 else { | ||||
|                     while (byteChunk[i] != '.') { | ||||
|                         measurement = measurement * 10 + byteChunk[i] - '0'; | ||||
|                         i++; | ||||
|                     } | ||||
|                     measurement = measurement * 10 + byteChunk[i + 1] - '0'; | ||||
|                 } | ||||
|  | ||||
|                 // Update aggregate | ||||
|                 var aggregate = stationAggregates.computeIfAbsent(station, x -> new PartitionAggregate()); | ||||
|                 aggregate.addMeasurementAndComputeAggregate((short) measurement); | ||||
|                 i += 3; | ||||
|             } | ||||
|             stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements); | ||||
|         } | ||||
|  | ||||
| @@ -138,40 +154,22 @@ public class CalculateAverage_flippingbits { | ||||
|     } | ||||
|  | ||||
|     private static class PartitionAggregate { | ||||
|         final short[] lane = new short[SIMD_LANE_LENGTH * 2]; | ||||
|         final short[] doubleLane = new short[SIMD_LANE_LENGTH * 2]; | ||||
|         // Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition | ||||
|         int count = 0; | ||||
|         long sum = 0; | ||||
|         short min = Short.MAX_VALUE; | ||||
|         short max = Short.MIN_VALUE; | ||||
|  | ||||
|         public void addMeasurementAndComputeAggregate(byte[] measurementBytes) { | ||||
|             // Parse measurement and exploit that we know the format of the floating-point values | ||||
|             var measurement = measurementBytes[measurementBytes.length - 1] - '0'; | ||||
|             var digits = 1; | ||||
|             for (var i = measurementBytes.length - 3; i > 0; i--) { | ||||
|                 var num = measurementBytes[i] - '0'; | ||||
|                 measurement = measurement + (num * (int) Math.pow(10, digits)); | ||||
|                 digits++; | ||||
|             } | ||||
|  | ||||
|             // Check if measurement is negative | ||||
|             if (measurementBytes[0] == '-') { | ||||
|                 measurement = measurement * -1; | ||||
|             } | ||||
|             else { | ||||
|                 var num = measurementBytes[0] - '0'; | ||||
|                 measurement = measurement + (num * (int) Math.pow(10, digits)); | ||||
|             } | ||||
|  | ||||
|         public void addMeasurementAndComputeAggregate(short measurement) { | ||||
|             // Add measurement to buffer, which is later processed by SIMD instructions | ||||
|             lane[count % lane.length] = (short) measurement; | ||||
|             doubleLane[count % doubleLane.length] = measurement; | ||||
|             count++; | ||||
|  | ||||
|             // Once lane is full, use SIMD instructions to calculate aggregates | ||||
|             if (count % lane.length == 0) { | ||||
|                 var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, lane, 0); | ||||
|                 var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, lane, SIMD_LANE_LENGTH); | ||||
|             if (count % doubleLane.length == 0) { | ||||
|                 var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, 0); | ||||
|                 var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, SIMD_LANE_LENGTH); | ||||
|  | ||||
|                 var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN); | ||||
|                 min = (short) Math.min(min, simdMin); | ||||
| @@ -184,8 +182,8 @@ public class CalculateAverage_flippingbits { | ||||
|         } | ||||
|  | ||||
|         public void aggregateRemainingMeasurements() { | ||||
|             for (var i = 0; i < count % lane.length; i++) { | ||||
|                 var measurement = lane[i]; | ||||
|             for (var i = 0; i < count % doubleLane.length; i++) { | ||||
|                 var measurement = doubleLane[i]; | ||||
|                 min = (short) Math.min(min, measurement); | ||||
|                 max = (short) Math.max(max, measurement); | ||||
|                 sum += measurement; | ||||
| @@ -204,7 +202,7 @@ public class CalculateAverage_flippingbits { | ||||
|                     Locale.US, | ||||
|                     "%.1f/%.1f/%.1f", | ||||
|                     (min / 10.0), | ||||
|                     (sum / 10.0) / count, | ||||
|                     ((sum / 10.0) / count), | ||||
|                     (max / 10.0)); | ||||
|         } | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user