Further improved performance by improving the parsing logic so that strings for city names are not allocated with each row. (#323)
Co-authored-by: Bruno Felix <bruno.felix@klarna.com>
This commit is contained in:
		| @@ -16,17 +16,16 @@ | ||||
| package dev.morling.onebrc; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.io.RandomAccessFile; | ||||
| import java.lang.foreign.Arena; | ||||
| import java.lang.foreign.MemorySegment; | ||||
| import java.lang.foreign.ValueLayout; | ||||
| import java.nio.ByteBuffer; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.nio.file.Paths; | ||||
| import java.util.ArrayList; | ||||
| import java.util.HashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import java.util.TreeMap; | ||||
| import java.util.concurrent.CompletableFuture; | ||||
| import java.util.concurrent.Executors; | ||||
| import java.util.stream.Collectors; | ||||
| @@ -36,6 +35,55 @@ public class CalculateAverage_felix19350 { | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|     private static final int NEW_LINE_SEEK_BUFFER_LEN = 128; | ||||
|  | ||||
|     private static final int EXPECTED_MAX_NUM_CITIES = 15_000; // 10K cities + a buffer so as not to trigger a resize at the load factor | ||||
|  | ||||
|     private static class CityRef { | ||||
          // Map key for a city name, held as raw bytes instead of a String.
          // It is used as the key type of the HashMap<CityRef, ResultRow> measurement maps;
          // a String is only built when cityName() is called.
|  | ||||
          // Number of bytes in the city name (same as stringBytes.length).
|         final int length; | ||||
          // Hash value supplied by the caller; returned as-is from hashCode().
|         final int fingerprint; | ||||
          // Private copy of the city-name bytes taken from the source buffer.
|         final byte[] stringBytes; | ||||
|  | ||||
          /**
           * Copies {@code length} bytes of {@code byteBuffer}, starting at index {@code startIdx},
           * into a newly allocated array and stores the given fingerprint.
           *
           * @param byteBuffer  buffer holding the city-name bytes
           * @param startIdx    index in {@code byteBuffer} of the first byte of the name
           * @param length      number of bytes to copy
           * @param fingerprint precomputed hash for these bytes
           */
|         public CityRef(ByteBuffer byteBuffer, int startIdx, int length, int fingerprint) { | ||||
|             this.length = length; | ||||
|             this.stringBytes = new byte[length]; | ||||
              // Indexed bulk get: reads from startIdx rather than from the buffer's current position.
|             byteBuffer.get(startIdx, this.stringBytes, 0, this.stringBytes.length); | ||||
|             this.fingerprint = fingerprint; | ||||
|         } | ||||
|  | ||||
          /**
           * Decodes the stored bytes as UTF-8; a new String is created on every call.
           *
           * @return the city name
           */
|         public String cityName() { | ||||
|             return new String(stringBytes, StandardCharsets.UTF_8); | ||||
|         } | ||||
|  | ||||
          /** Returns the fingerprint passed to the constructor. */
|         @Override | ||||
|         public int hashCode() { | ||||
|             return fingerprint; | ||||
|         } | ||||
|  | ||||
          /**
           * Two references are equal when fingerprint, length and every stored byte match.
           * Any object that is not a {@code CityRef} compares as not equal.
           */
|         @Override | ||||
|         public boolean equals(Object other) { | ||||
|             if (other instanceof CityRef otherRef) { | ||||
                  // Cheap integer checks first; the byte loop below only runs when both match.
|                 if (fingerprint != otherRef.fingerprint) { | ||||
|                     return false; | ||||
|                 } | ||||
|  | ||||
|                 if (this.length != otherRef.length) { | ||||
|                     return false; | ||||
|                 } | ||||
|  | ||||
                  // Same fingerprint and length: compare the names byte by byte.
|                 for (var i = 0; i < this.length; i++) { | ||||
|                     if (this.stringBytes[i] != otherRef.stringBytes[i]) { | ||||
|                         return false; | ||||
|                     } | ||||
|                 } | ||||
|                 return true; | ||||
|             } | ||||
|             else { | ||||
|                 return false; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|     } | ||||
|  | ||||
|     private static class ResultRow { | ||||
|  | ||||
|         private int min; | ||||
| @@ -73,95 +121,104 @@ public class CalculateAverage_felix19350 { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|   private record AverageAggregatorTask(MemorySegment memSegment) { | ||||
|   private record AverageAggregatorTask(ByteBuffer byteBuffer) { | ||||
|     private static final int HASH_FACTOR = 31; // Mersenne prime | ||||
|  | ||||
|     public static Stream<AverageAggregatorTask> createStreamOf(List<MemorySegment> memorySegments) { | ||||
|       return memorySegments.stream().map(AverageAggregatorTask::new); | ||||
|  | ||||
|     public static Stream<AverageAggregatorTask> createStreamOf(List<ByteBuffer> byteBuffers) { | ||||
|       return byteBuffers.stream().map(AverageAggregatorTask::new); | ||||
|     } | ||||
|  | ||||
|     public Map<String, ResultRow> processChunk() { | ||||
|       final var result = new TreeMap<String, ResultRow>(); | ||||
|       var offset = 0L; | ||||
|       var lineStart = 0L; | ||||
|       while (offset < memSegment.byteSize()) { | ||||
|         byte nextByte = memSegment.get(ValueLayout.OfByte.JAVA_BYTE, offset); | ||||
|         if ((char) nextByte == '\n') { | ||||
|           this.processLine(result, memSegment.asSlice(lineStart, (offset - lineStart)).asByteBuffer()); | ||||
|           lineStart = offset + ValueLayout.JAVA_BYTE.byteSize(); | ||||
|         } | ||||
|         offset += ValueLayout.OfByte.JAVA_BYTE.byteSize(); | ||||
|     public Map<CityRef, ResultRow> processChunk() { | ||||
|       final var measurements = new HashMap<CityRef, ResultRow>(EXPECTED_MAX_NUM_CITIES); | ||||
|       var lineStart = 0; | ||||
|       // process line by line, relying on the fact that a line is no longer than 106 bytes: | ||||
|       // 100 bytes for city name + 1 byte for separator + 1 byte for negative sign + 4 bytes for number | ||||
|       while (lineStart < byteBuffer.limit()) { | ||||
|         lineStart = this.processLine(measurements, byteBuffer, lineStart); | ||||
|       } | ||||
|  | ||||
|       return result; | ||||
|       return measurements; | ||||
|     } | ||||
|  | ||||
|     private void processLine(Map<String, ResultRow> result, ByteBuffer lineBytes) { | ||||
|     private int processLine(Map<CityRef, ResultRow> measurements, ByteBuffer byteBuffer, int start) { | ||||
|       var fingerPrint = 0; | ||||
|       var separatorIdx = -1; | ||||
|       for (int i = 0; i < lineBytes.limit(); i++) { | ||||
|         if ((char) lineBytes.get() == ';') { | ||||
|           separatorIdx = i; | ||||
|           lineBytes.clear(); | ||||
|           break; | ||||
|       var sign = 1; | ||||
|       var value = 0; | ||||
|       var lineEnd = -1; | ||||
|       // Lines are processed in two stages: | ||||
|       // 1 - prior to the city name separator | ||||
|       // 2 - after the separator | ||||
|       // this ensures less if clauses | ||||
|  | ||||
|       // stage 1 loop | ||||
|       { | ||||
|         for (int i = 0; i < NEW_LINE_SEEK_BUFFER_LEN; i++) { | ||||
|           final var currentByte = byteBuffer.get(start + i); | ||||
|           if (currentByte == ';') { | ||||
|             separatorIdx = i; | ||||
|             break; | ||||
|           } else { | ||||
|             fingerPrint = HASH_FACTOR * fingerPrint + currentByte; | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       // stage 2 loop: | ||||
|       { | ||||
|         for (int i = separatorIdx + 1; i < NEW_LINE_SEEK_BUFFER_LEN; i++) { | ||||
|           final var currentByte = byteBuffer.get(start + i); | ||||
|           switch (currentByte) { | ||||
|             case '-': | ||||
|               sign = -1; | ||||
|               break; | ||||
|             case '.': | ||||
|               break; | ||||
|             case '\n': | ||||
|               lineEnd = start + i + 1; | ||||
|               break; | ||||
|             default: | ||||
|               // only digits are expected here | ||||
|               value = value * 10 + (currentByte - '0'); | ||||
|           } | ||||
|  | ||||
|           if (lineEnd != -1) { | ||||
|             break; | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       assert (separatorIdx > 0); | ||||
|       final var cityRef = new CityRef(byteBuffer, start, separatorIdx,fingerPrint); | ||||
|       value = sign * value; | ||||
|  | ||||
|       var valueCapacity = lineBytes.capacity() - (separatorIdx + 1); | ||||
|       var cityBytes = new byte[separatorIdx]; | ||||
|       var valueBytes = new byte[valueCapacity]; | ||||
|       lineBytes.get(cityBytes, 0, separatorIdx); | ||||
|       lineBytes.get(separatorIdx + 1, valueBytes); | ||||
|  | ||||
|       var city = new String(cityBytes, StandardCharsets.UTF_8); | ||||
|       var value = parseInt(valueBytes); | ||||
|  | ||||
|       var latestValue = result.get(city); | ||||
|       if (latestValue != null) { | ||||
|         latestValue.mergeValue(value); | ||||
|       final var existingMeasurement = measurements.get(cityRef); | ||||
|       if (existingMeasurement == null) { | ||||
|         measurements.put(cityRef, new ResultRow(value)); | ||||
|       } else { | ||||
|         result.put(city, new ResultRow(value)); | ||||
|         existingMeasurement.mergeValue(value); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     private static int parseInt(byte[] valueBytes) { | ||||
|       int multiplier = 1; | ||||
|       int digitValue = 0; | ||||
|       var numDigits = valueBytes.length-1; // there is always one decimal place | ||||
|       var ds = new int[]{1,10,100}; | ||||
|  | ||||
|       for (byte valueByte : valueBytes) { | ||||
|         switch ((char) valueByte) { | ||||
|           case '-': | ||||
|             multiplier = -1; | ||||
|             numDigits -= 1; | ||||
|             break; | ||||
|           case '.': | ||||
|             break; | ||||
|           default: | ||||
|             digitValue += ((int) valueByte - 48) * (ds[numDigits - 1]); | ||||
|             numDigits -= 1; | ||||
|             break;// TODO continue here | ||||
|         } | ||||
|       } | ||||
|       return multiplier*digitValue; | ||||
|       return lineEnd; //to account for the line end | ||||
|     } | ||||
|   } | ||||
|  | ||||
|     public static void main(String[] args) throws IOException { | ||||
|         // memory map the files and divide by number of cores | ||||
|         var numProcessors = Runtime.getRuntime().availableProcessors(); | ||||
|         var memorySegments = calculateMemorySegments(numProcessors); | ||||
|         var tasks = AverageAggregatorTask.createStreamOf(memorySegments); | ||||
|         assert (memorySegments.size() == numProcessors); | ||||
|         final var numProcessors = Runtime.getRuntime().availableProcessors(); | ||||
|         final var byteBuffers = calculateMemorySegments(numProcessors); | ||||
|         final var tasks = AverageAggregatorTask.createStreamOf(byteBuffers); | ||||
|         assert (byteBuffers.size() <= numProcessors); | ||||
|         assert (!byteBuffers.isEmpty()); | ||||
|  | ||||
|         try (var pool = Executors.newFixedThreadPool(numProcessors)) { | ||||
|             var results = tasks | ||||
|             final Map<CityRef, ResultRow> aggregatedCities = tasks | ||||
|                     .parallel() | ||||
|                     .map(task -> CompletableFuture.supplyAsync(task::processChunk, pool)) | ||||
|                     .map(CompletableFuture::join) | ||||
|                     .reduce(new TreeMap<>(), (partialMap, accumulator) -> { | ||||
|                         partialMap.forEach((key, value) -> { | ||||
|                             var prev = accumulator.get(key); | ||||
|                     .reduce(new HashMap<>(EXPECTED_MAX_NUM_CITIES), (currentMap, accumulator) -> { | ||||
|                         currentMap.forEach((key, value) -> { | ||||
|                             final var prev = accumulator.get(key); | ||||
|                             if (prev == null) { | ||||
|                                 accumulator.put(key, value); | ||||
|                             } | ||||
| @@ -172,6 +229,9 @@ public class CalculateAverage_felix19350 { | ||||
|                         return accumulator; | ||||
|                     }); | ||||
|  | ||||
|             var results = new HashMap<String, ResultRow>(EXPECTED_MAX_NUM_CITIES); | ||||
|             aggregatedCities.forEach((key, value) -> results.put(key.cityName(), value)); | ||||
|  | ||||
|             System.out.print("{"); | ||||
|             String output = results.keySet() | ||||
|                     .stream() | ||||
| @@ -183,16 +243,16 @@ public class CalculateAverage_felix19350 { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static List<MemorySegment> calculateMemorySegments(int numChunks) throws IOException { | ||||
|         try (RandomAccessFile raf = new RandomAccessFile(FILE, "r")) { | ||||
|             var result = new ArrayList<MemorySegment>(numChunks); | ||||
|             var chunks = new ArrayList<long[]>(numChunks); | ||||
|     private static List<ByteBuffer> calculateMemorySegments(int numChunks) throws IOException { | ||||
|         try (FileChannel fc = FileChannel.open(Paths.get(FILE))) { | ||||
|             var memMappedFile = fc.map(FileChannel.MapMode.READ_ONLY, 0L, fc.size(), Arena.ofAuto()); | ||||
|             var result = new ArrayList<ByteBuffer>(numChunks); | ||||
|  | ||||
|             var fileSize = raf.length(); | ||||
|             var chunkSize = fileSize / numChunks; | ||||
|             var fileSize = fc.size(); | ||||
|             var chunkSize = fileSize / numChunks; // TODO: if chunksize > MAX INT we will need to adjust | ||||
|             var previousChunkEnd = 0L; | ||||
|  | ||||
|             for (int i = 0; i < numChunks; i++) { | ||||
|                 var previousChunkEnd = i == 0 ? 0L : chunks.get(i - 1)[1]; | ||||
|                 if (previousChunkEnd >= fileSize) { | ||||
|                     // There is a scenario for very small files where the number of chunks may be greater than | ||||
|                     // the number of lines. | ||||
| @@ -205,31 +265,27 @@ public class CalculateAverage_felix19350 { | ||||
|                 } | ||||
|                 else { | ||||
|                     // all other chunks end at a new line (\n) | ||||
|                     var theoreticalEnd = previousChunkEnd + chunkSize; | ||||
|                     var buffer = new byte[NEW_LINE_SEEK_BUFFER_LEN]; | ||||
|                     raf.seek(theoreticalEnd); | ||||
|                     raf.read(buffer, 0, NEW_LINE_SEEK_BUFFER_LEN); | ||||
|  | ||||
|                     var theoreticalEnd = Math.min(previousChunkEnd + chunkSize, fileSize); | ||||
|                     var newLineOffset = 0; | ||||
|                     for (byte b : buffer) { | ||||
|                     for (int j = 0; j < NEW_LINE_SEEK_BUFFER_LEN; j++) { | ||||
|                         var candidateOffset = theoreticalEnd + j; | ||||
|                         if (candidateOffset >= fileSize) { | ||||
|                             break; | ||||
|                         } | ||||
|                         byte b = memMappedFile.get(ValueLayout.OfByte.JAVA_BYTE, candidateOffset); | ||||
|                         newLineOffset += 1; | ||||
|                         if ((char) b == '\n') { | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|                     chunk[1] = Math.min(fileSize, theoreticalEnd + newLineOffset); | ||||
|                     previousChunkEnd = chunk[1]; | ||||
|                 } | ||||
|  | ||||
|                 assert (chunk[0] >= 0L); | ||||
|                 assert (chunk[0] <= fileSize); | ||||
|                 assert (chunk[1] > chunk[0]); | ||||
|                 assert (chunk[1] <= fileSize); | ||||
|  | ||||
|                 var memMappedFile = raf.getChannel() | ||||
|                         .map(FileChannel.MapMode.READ_ONLY, chunk[0], (chunk[1] - chunk[0]), Arena.ofAuto()); | ||||
|                 memMappedFile.load(); | ||||
|                 chunks.add(chunk); | ||||
|                 result.add(memMappedFile); | ||||
|                 result.add(memMappedFile.asSlice(chunk[0], (chunk[1] - chunk[0])).asByteBuffer()); | ||||
|             } | ||||
|             return result; | ||||
|         } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user