Add improvements (#412)
- custom hashmap
- avoid string creation (see the sketch below)
- use Graal
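The "avoid string creation" item refers to keeping station keys as raw byte slices during aggregation and only materializing a `String` once per distinct station at the end. A minimal, self-contained sketch of the idea (names like `ByteKeyDemo` and `hash` are illustrative, not taken from the diff below):

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

class ByteKeyDemo {
    // Hash a key slice in place; no String is allocated per input line.
    static int hash(byte[] buf, int from, int to) {
        int h = 0;
        for (int i = from; i < to; i++) {
            h = h * 31 + buf[i];
        }
        return h;
    }

    public static void main(String[] args) {
        byte[] line = "Hamburg;12.3\n".getBytes(StandardCharsets.UTF_8);
        int semicolon = 7; // position of ';' in this sample line
        byte[] key = Arrays.copyOfRange(line, 0, semicolon);
        // The custom map in the diff stores byte[] keys and compares them with
        // Arrays.equals, so Strings are only created once, in toMap().
        System.out.println(hash(line, 0, semicolon) + " " + new String(key, StandardCharsets.UTF_8));
    }
}
```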
		| @@ -15,18 +15,24 @@ | ||||
|  */ | ||||
| package dev.morling.onebrc; | ||||
|  | ||||
| import static java.nio.charset.StandardCharsets.*; | ||||
| import static java.util.stream.Collectors.*; | ||||
|  | ||||
| import java.io.File; | ||||
| import java.io.RandomAccessFile; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.HashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import java.util.TreeMap; | ||||
| import java.util.concurrent.ConcurrentHashMap; | ||||
| import java.util.concurrent.Callable; | ||||
| import java.util.concurrent.ExecutionException; | ||||
| import java.util.concurrent.ExecutorService; | ||||
| import java.util.concurrent.Executors; | ||||
| import java.util.concurrent.Future; | ||||
| import java.util.concurrent.TimeUnit; | ||||
| import java.util.stream.Collectors; | ||||
| import java.util.stream.IntStream; | ||||
|  | ||||
| public class CalculateAverage_phd3 { | ||||
| @@ -34,12 +40,16 @@ public class CalculateAverage_phd3 { | ||||
|     private static final int NUM_THREADS = Runtime.getRuntime().availableProcessors() * 2; | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|     private static final long FILE_SIZE = new File(FILE).length(); | ||||
|     // A chunk is the unit of processing; the file is divided into chunks of the following size | ||||
|     private static final int CHUNK_SIZE = 65536 * 1024; | ||||
|     // Read a little more data into the buffer to finish processing the current line | ||||
|     private static final int PADDING = 512; | ||||
|     // Minor: precompute powers of 10 to avoid recalculating them while parsing temperatures | ||||
|     private static final double[] POWERS_OF_10 = IntStream.range(0, 6).mapToDouble(x -> Math.pow(10.0, x)).toArray(); | ||||
|  | ||||
|     private static final Map<String, AggregationInfo> globalMap = new ConcurrentHashMap<>(); | ||||
|  | ||||
|     /** | ||||
|      * A utility to print aggregated information in the desired format | ||||
|      */ | ||||
|     private record ResultRow(double min, double mean, double max) { | ||||
|  | ||||
|         public String toString() { | ||||
| @@ -52,7 +62,7 @@ public class CalculateAverage_phd3 { | ||||
|     }; | ||||
|  | ||||
|     public static ResultRow resultRow(AggregationInfo aggregationInfo) { | ||||
|         return new ResultRow(aggregationInfo.min, aggregationInfo.sum / aggregationInfo.count, aggregationInfo.max); | ||||
|         return new ResultRow(aggregationInfo.min, (Math.round(aggregationInfo.sum * 10.0) / 10.0) / (aggregationInfo.count), aggregationInfo.max); | ||||
|     } | ||||
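The `resultRow` change above rounds the running sum to one decimal place before dividing, so the mean is computed from an exact one-decimal total rather than a drifted one. A small illustration with hypothetical readings (each input value carries exactly one decimal digit):

```java
import java.util.stream.DoubleStream;

class RoundingDemo {
    public static void main(String[] args) {
        // Mathematically 38.5, but binary doubles cannot represent 12.9 or 12.8
        // exactly, so error accumulates in the sum.
        double sum = DoubleStream.of(12.9, 12.8, 12.8).sum();
        long count = 3;
        double naive = sum / count;                               // divides the drifted sum
        double rounded = (Math.round(sum * 10.0) / 10.0) / count; // restores the exact total first
        System.out.println(naive + " vs " + rounded);
    }
}
```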
|  | ||||
|     public static void main(String[] args) throws Exception { | ||||
| @@ -60,19 +70,37 @@ public class CalculateAverage_phd3 { | ||||
|         int numChunks = (int) Math.ceil(fileLength * 1.0 / CHUNK_SIZE); | ||||
|         ExecutorService executorService = Executors.newFixedThreadPool(NUM_THREADS); | ||||
|         BufferDataProvider provider = new RandomAccessBasedProvider(FILE, FILE_SIZE); | ||||
|         List<Future<LinearProbingHashMap>> futures = new ArrayList<>(); | ||||
|         // Process chunks in parallel | ||||
|         for (int chunkIndex = 0; chunkIndex < numChunks; chunkIndex++) { | ||||
|             executorService.submit(new Aggregator(chunkIndex, provider)); | ||||
|             futures.add(executorService.submit(new Aggregator(chunkIndex, provider))); | ||||
|         } | ||||
|  | ||||
|         executorService.shutdown(); | ||||
|         executorService.awaitTermination(10, TimeUnit.MINUTES); | ||||
|  | ||||
|         Map<String, ResultRow> measurements = new TreeMap<>(globalMap.entrySet().stream() | ||||
|         Map<String, AggregationInfo> info = futures.stream().map(f -> { | ||||
|             try { | ||||
|                 return f.get(); | ||||
|             } | ||||
|             catch (ExecutionException | InterruptedException e) { | ||||
|                 throw new RuntimeException(e); | ||||
|             } | ||||
|         }) | ||||
|                 .map(LinearProbingHashMap::toMap) | ||||
|                 .flatMap(map -> map.entrySet().stream()) | ||||
|                 .sequential() | ||||
|                 .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, AggregationInfo::update)); | ||||
|  | ||||
|         Map<String, ResultRow> measurements = new TreeMap<>(info.entrySet().stream() | ||||
|                 .collect(toMap(Map.Entry::getKey, e -> resultRow(e.getValue())))); | ||||
|  | ||||
|         System.out.println(measurements); | ||||
|     } | ||||
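Each future yields one partial map per chunk; the three-argument `Collectors.toMap` merges entries for the same station via `AggregationInfo::update`. The same merge shape in isolation, with simplified value types for illustration:

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

class MergeDemo {
    public static void main(String[] args) {
        // Two partial results for the same station, as two map entries.
        List<Map.Entry<String, Double>> partials = List.of(
                Map.entry("Oslo", 2.0), Map.entry("Oslo", 4.0));
        // The third argument resolves key collisions, like AggregationInfo::update above.
        Map<String, Double> merged = partials.stream()
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, Double::sum));
        System.out.println(merged); // {Oslo=6.0}
    }
}
```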
|  | ||||
|     /** | ||||
|      * Stores the running aggregation state needed to compute min/max/average at the end | ||||
|      */ | ||||
|     private static class AggregationInfo { | ||||
|         double min = Double.POSITIVE_INFINITY; | ||||
|         double max = Double.NEGATIVE_INFINITY; | ||||
| @@ -108,13 +136,14 @@ public class CalculateAverage_phd3 { | ||||
|         int read(byte[] buffer, long offset) throws Exception; | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Uses RandomAccessFile's seek and read APIs to load data into a buffer. | ||||
|      */ | ||||
|     private static class RandomAccessBasedProvider implements BufferDataProvider { | ||||
|         private final String filePath; | ||||
|         private final long fileSize; | ||||
|  | ||||
|         RandomAccessBasedProvider(String filePath, long fileSize) { | ||||
|             this.filePath = filePath; | ||||
|             this.fileSize = fileSize; | ||||
|         } | ||||
|  | ||||
|         @Override | ||||
| @@ -133,7 +162,10 @@ public class CalculateAverage_phd3 { | ||||
|         } | ||||
|     } | ||||
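The body of `read(buffer, offset)` is elided by the hunk above. A plausible shape for a seek-and-read implementation, using the field names from the class but not necessarily the committed code:

```java
// Sketch only: open the file, seek to the chunk's offset, and fill the buffer.
// Error handling is simplified for illustration.
public int read(byte[] buffer, long offset) throws Exception {
    try (RandomAccessFile file = new RandomAccessFile(filePath, "r")) {
        file.seek(offset);
        int toRead = (int) Math.min(buffer.length, fileSize - offset);
        int total = 0;
        while (total < toRead) {
            int n = file.read(buffer, total, toRead - total);
            if (n < 0) {
                break; // end of file reached early
            }
            total += n;
        }
        return total;
    }
}
```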
|  | ||||
|     private static class Aggregator implements Runnable { | ||||
|     /** | ||||
|      * Task that processes a chunk of the file and returns a custom hash map of partial aggregates for performance | ||||
|      */ | ||||
|     private static class Aggregator implements Callable<LinearProbingHashMap> { | ||||
|         private final long startByte; | ||||
|         private final BufferDataProvider dataProvider; | ||||
|  | ||||
| @@ -143,7 +175,7 @@ public class CalculateAverage_phd3 { | ||||
|         } | ||||
|  | ||||
|         @Override | ||||
|         public void run() { | ||||
|         public LinearProbingHashMap call() { | ||||
|             try { | ||||
|                 // offset for the last byte to be processed (excluded) | ||||
|                 long endByte = Math.min(startByte + CHUNK_SIZE, FILE_SIZE); | ||||
| @@ -151,25 +183,15 @@ public class CalculateAverage_phd3 { | ||||
|                 long bufferSize = endByte - startByte + ((endByte == FILE_SIZE) ? 0 : PADDING); | ||||
|                 byte[] buffer = new byte[(int) bufferSize]; | ||||
|                 int bytes = dataProvider.read(buffer, startByte); | ||||
|                 // Partial aggregation to avoid accessing global concurrent map for every entry | ||||
|                 Map<String, AggregationInfo> updated = processBuffer( | ||||
|                         buffer, startByte == 0, endByte - startByte); | ||||
|                 // Full aggregation with global map | ||||
|                 updated.entrySet().forEach(entry -> { | ||||
|                     globalMap.compute(entry.getKey(), (k, v) -> { | ||||
|                         if (v == null) { | ||||
|                             return entry.getValue(); | ||||
|                         } | ||||
|                         return v.update(entry.getValue()); | ||||
|                     }); | ||||
|                 }); | ||||
|                 // Partial aggregation in a hashmap | ||||
|                 return processBuffer(buffer, startByte == 0, endByte - startByte); | ||||
|             } | ||||
|             catch (Throwable e) { | ||||
|                 throw new RuntimeException(e); | ||||
|             } | ||||
|         } | ||||
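Chunk boundaries rarely coincide with newlines, which is why every buffer except the last is over-read by `PADDING` bytes and why `processBuffer` skips ahead to the first `'\n'` for every chunk except the first. A worked example with hypothetical sizes (the real `CHUNK_SIZE` is 65536 * 1024 and `PADDING` is 512):

```java
class ChunkDemo {
    public static void main(String[] args) {
        long fileSize = 150;
        int chunkSize = 100;
        int padding = 16;
        int numChunks = (int) Math.ceil(fileSize * 1.0 / chunkSize); // 2
        for (int chunk = 0; chunk < numChunks; chunk++) {
            long startByte = (long) chunk * chunkSize;
            long endByte = Math.min(startByte + chunkSize, fileSize);
            // Every chunk except the last over-reads by `padding` bytes so that the
            // line straddling the boundary can be finished by the chunk that owns it.
            long bufferSize = endByte - startByte + ((endByte == fileSize) ? 0 : padding);
            System.out.println("chunk " + chunk + ": [" + startByte + ", " + endByte + ") buffer=" + bufferSize);
        }
    }
}
```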
|  | ||||
|         private static Map<String, AggregationInfo> processBuffer(byte[] buffer, boolean isFileStart, long nextChunkStart) { | ||||
|         private static LinearProbingHashMap processBuffer(byte[] buffer, boolean isFileStart, long nextChunkStart) { | ||||
|             int start = 0; | ||||
|             // Move to the next entry after '\n'. Don't do this if we're at the start of | ||||
|             // the file, to avoid missing the first entry. | ||||
| @@ -180,13 +202,15 @@ public class CalculateAverage_phd3 { | ||||
|                 start += 1; | ||||
|             } | ||||
|  | ||||
|             // local map for this thread, don't need thread safety | ||||
|             Map<String, AggregationInfo> chunkMap = new HashMap<>(); | ||||
|             LinearProbingHashMap chunkLocalMap = new LinearProbingHashMap(); | ||||
|             while (true) { | ||||
|                 LineInfo lineInfo = getNextLine(buffer, start); | ||||
|                 String key = new String(buffer, start, lineInfo.semicolonIndex - start); | ||||
|                 byte[] keyBytes = new byte[lineInfo.semicolonIndex - start]; | ||||
|                 System.arraycopy(buffer, start, keyBytes, 0, keyBytes.length); | ||||
|                 double value = parseDouble(buffer, lineInfo.semicolonIndex + 1, lineInfo.nextStart - 1); | ||||
|                 update(chunkMap, key, value); | ||||
|                 // Update the aggregated value for this key with the parsed measurement | ||||
|                 AggregationInfo info = chunkLocalMap.get(keyBytes, lineInfo.keyHash); | ||||
|                 info.update(value); | ||||
|  | ||||
|                 if ((lineInfo.nextStart > nextChunkStart) || (lineInfo.nextStart >= buffer.length)) { | ||||
|                     // we are already at a point where the next line will be processed in the next chunk, | ||||
| @@ -196,9 +220,12 @@ public class CalculateAverage_phd3 { | ||||
|  | ||||
|                 start = lineInfo.nextStart(); | ||||
|             } | ||||
|             return chunkMap; | ||||
|             return chunkLocalMap; | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Converts bytes to a double without intermediate String creation; faster than Double.parseDouble. | ||||
|          */ | ||||
|         private static double parseDouble(byte[] bytes, int offset, int end) { | ||||
|             boolean negative = (bytes[offset] == '-'); | ||||
|             int current = negative ? offset + 1 : offset; | ||||
| @@ -216,26 +243,97 @@ public class CalculateAverage_phd3 { | ||||
|             return (preFloat + ((postFloat) / POWERS_OF_10[end - postFloatStart])) * (negative ? -1 : 1); | ||||
|         } | ||||
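The hunk elides the middle of `parseDouble`. A reconstruction consistent with the variable names in the surrounding lines (`current`, `preFloat`, `postFloat`, `postFloatStart`), offered as a sketch rather than the committed code:

```java
// Accumulate digits before '.' into preFloat and digits after it into postFloat;
// the return statement above then scales postFloat by POWERS_OF_10[end - postFloatStart].
int preFloat = 0;
while (bytes[current] != '.') {
    preFloat = preFloat * 10 + (bytes[current] - '0');
    current++;
}
int postFloatStart = ++current; // skip the '.'
int postFloat = 0;
while (current < end) {
    postFloat = postFloat * 10 + (bytes[current] - '0');
    current++;
}
```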
|  | ||||
|         private static void update(Map<String, AggregationInfo> state, String key, double value) { | ||||
|             AggregationInfo info = state.computeIfAbsent(key, k -> new AggregationInfo()); | ||||
|             info.update(value); | ||||
|         } | ||||
|  | ||||
|         // identifies indexes of the next ';' and '\n', which will be used to get entry key and value from line | ||||
|         /** | ||||
|          * Identifies the indexes of the next ';' and '\n', which are used to extract the entry key and value from the line. Also | ||||
|          * computes the hash value for the key while iterating. | ||||
|          */ | ||||
|         private static LineInfo getNextLine(byte[] buffer, int start) { | ||||
|             // caller guarantees that the access is in bounds, so no index check | ||||
|             int hash = 0; | ||||
|             while (buffer[start] != ';') { | ||||
|                 hash = hash * 31 + buffer[start]; // hash every key byte, excluding the ';' | ||||
|                 start++; | ||||
|             } | ||||
|             // The following is just to further reduce the probability of collisions | ||||
|             hash = hash ^ (hash << 16); | ||||
|             int semicolonIndex = start; | ||||
|             // caller guarantees that the access is in bounds, so no index check | ||||
|             while (buffer[start] != '\n') { | ||||
|                 start++; | ||||
|             } | ||||
|             return new LineInfo(semicolonIndex, start + 1); | ||||
|             return new LineInfo(semicolonIndex, start + 1, hash); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private record LineInfo(int semicolonIndex, int nextStart) { | ||||
|     private record LineInfo(int semicolonIndex, int nextStart, int keyHash) { | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * A simple map with a fixed, pre-configured bucket count. With 2^13 buckets and the current hash function, about | ||||
|      * four collisions per bucket were observed, which is acceptable. Despite the name, collisions are resolved by | ||||
|      * chaining: every bucket is a linked list. The map is NOT thread-safe. | ||||
|      */ | ||||
|     private static class LinearProbingHashMap { | ||||
|         private static final int BUCKET_COUNT = 8192; // 2^13; a power of two so (BUCKET_COUNT - 1) is a valid bit mask | ||||
|         private final Node[] buckets; | ||||
|  | ||||
|         LinearProbingHashMap() { | ||||
|             this.buckets = new Node[BUCKET_COUNT]; | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Given a key, returns its current AggregationInfo. If absent, creates a new empty entry at the front of | ||||
|          * the bucket's chain. | ||||
|          */ | ||||
|         public AggregationInfo get(byte[] key, int keyHash) { | ||||
|             // find bucket index through bitwise AND; works because the bucket count is a power of two | ||||
|             int bucketIndex = keyHash & (BUCKET_COUNT - 1); | ||||
|             Node current = buckets[bucketIndex]; | ||||
|             while (current != null) { | ||||
|                 if (Arrays.equals(current.entry.key(), key)) { | ||||
|                     return current.entry.aggregationInfo(); | ||||
|                 } | ||||
|                 current = current.next; | ||||
|             } | ||||
|  | ||||
|             // Entry does not exist, so add a new node in the linked list | ||||
|             AggregationInfo newInfo = new AggregationInfo(); | ||||
|             KeyValuePair pair = new KeyValuePair(key, keyHash, newInfo); | ||||
|             Node newNode = new Node(pair, buckets[bucketIndex]); | ||||
|             buckets[bucketIndex] = newNode; | ||||
|             return newNode.entry.aggregationInfo(); | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * A helper to convert this map into a java.util.HashMap, from which the final aggregation is built | ||||
|          */ | ||||
|         private Map<String, AggregationInfo> toMap() { | ||||
|             Map<String, AggregationInfo> map = new HashMap<>(); | ||||
|             for (Node bucket : buckets) { | ||||
|                 while (bucket != null) { | ||||
|                     map.put(new String(bucket.entry.key, StandardCharsets.UTF_8), bucket.entry.aggregationInfo()); | ||||
|                     bucket = bucket.next; | ||||
|                 } | ||||
|             } | ||||
|             return map; | ||||
|         } | ||||
|     } | ||||
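A hypothetical end-to-end use of the structures above, written as if inside the enclosing class (`toMap()` is private, so this only compiles there); the hash is computed the same way `getNextLine` does:

```java
byte[] key = "Hamburg".getBytes(StandardCharsets.UTF_8);
int hash = 0;
for (byte b : key) {
    hash = hash * 31 + b;
}
hash = hash ^ (hash << 16); // same mixing step as getNextLine

LinearProbingHashMap map = new LinearProbingHashMap();
map.get(key, hash).update(12.3); // first access creates the entry
map.get(key, hash).update(14.1);
Map<String, AggregationInfo> merged = map.toMap(); // one String per distinct station
```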
|  | ||||
|     /** | ||||
|      * Linked-list node implementing a bucket of the custom hash map | ||||
|      */ | ||||
|     private static class Node { | ||||
|         KeyValuePair entry; | ||||
|         Node next; | ||||
|  | ||||
|         public Node(KeyValuePair entry, Node next) { | ||||
|             this.entry = entry; | ||||
|             this.next = next; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * A wrapper class storing the information needed to keep a measurement entry in the hash map | ||||
|      */ | ||||
|     private record KeyValuePair(byte[] key, int keyHash, AggregationInfo aggregationInfo) { | ||||
|     } | ||||
| } | ||||
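One design note on `KeyValuePair`: records derive `equals` and `hashCode` from their fields, and for a `byte[]` field that means reference identity. This is why the map compares keys with `Arrays.equals` and carries the precomputed `keyHash` explicitly instead of relying on the record:

```java
import static java.nio.charset.StandardCharsets.UTF_8;

import java.util.Arrays;

class RecordKeyDemo {
    public static void main(String[] args) {
        byte[] a = "Oslo".getBytes(UTF_8);
        byte[] b = "Oslo".getBytes(UTF_8);
        System.out.println(a.equals(b));         // false: arrays compare by identity
        System.out.println(Arrays.equals(a, b)); // true: content comparison, as in get(...)
    }
}
```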
|   | ||||