diff --git a/prepare_phd3.sh b/prepare_phd3.sh
new file mode 100755
index 0000000..f83a3ff
--- /dev/null
+++ b/prepare_phd3.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+#
+#  Copyright 2023 The original authors
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
+source "$HOME/.sdkman/bin/sdkman-init.sh"
+sdk use java 21.0.1-graal 1>&2
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_phd3.java b/src/main/java/dev/morling/onebrc/CalculateAverage_phd3.java
index e3d1cdb..97f832b 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_phd3.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_phd3.java
@@ -15,18 +15,24 @@
  */
 package dev.morling.onebrc;
 
-import static java.nio.charset.StandardCharsets.*;
 import static java.util.stream.Collectors.*;
 
 import java.io.File;
 import java.io.RandomAccessFile;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.TreeMap;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 public class CalculateAverage_phd3 {
@@ -34,12 +40,16 @@ public class CalculateAverage_phd3 {
     private static final int NUM_THREADS = Runtime.getRuntime().availableProcessors() * 2;
     private static final String FILE = "./measurements.txt";
     private static final long FILE_SIZE = new File(FILE).length();
+    // A chunk is a unit of processing; the file is divided into chunks of the following size
     private static final int CHUNK_SIZE = 65536 * 1024;
+    // Read a little more data into the buffer so that processing of the current line can finish
     private static final int PADDING = 512;
+    // Minor: precompute powers of 10 to avoid recalculating them while parsing doubles (temperatures)
     private static final double[] POWERS_OF_10 = IntStream.range(0, 6).mapToDouble(x -> Math.pow(10.0, x)).toArray();
 
-    private static final Map<String, AggregationInfo> globalMap = new ConcurrentHashMap<>();
-
+    /**
+     * A utility to print aggregated information in the desired format
+     */
    private record ResultRow(double min, double mean, double max) {
 
         public String toString() {
@@ -52,7 +62,7 @@ public class CalculateAverage_phd3 {
     };
 
     public static ResultRow resultRow(AggregationInfo aggregationInfo) {
-        return new ResultRow(aggregationInfo.min, aggregationInfo.sum / aggregationInfo.count, aggregationInfo.max);
+        return new ResultRow(aggregationInfo.min, (Math.round(aggregationInfo.sum * 10.0) / 10.0) / (aggregationInfo.count), aggregationInfo.max);
     }
 
     public static void main(String[] args) throws Exception {
@@ -60,19 +70,38 @@
         int numChunks = (int) Math.ceil(fileLength * 1.0 / CHUNK_SIZE);
         ExecutorService executorService = Executors.newFixedThreadPool(NUM_THREADS);
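+        // One Aggregator task is submitted per chunk; the last chunk may be shorter than CHUNK_SIZE since numChunks rounds up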
         BufferDataProvider provider = new RandomAccessBasedProvider(FILE, FILE_SIZE);
+        List<Future<LinearProbingHashMap>> futures = new ArrayList<>();
 
+        // Process chunks in parallel
         for (int chunkIndex = 0; chunkIndex < numChunks; chunkIndex++) {
-            executorService.submit(new Aggregator(chunkIndex, provider));
+            futures.add(executorService.submit(new Aggregator(chunkIndex, provider)));
         }
 
         executorService.shutdown();
         executorService.awaitTermination(10, TimeUnit.MINUTES);
 
-        Map<String, ResultRow> measurements = new TreeMap<>(globalMap.entrySet().stream()
+        Map<String, AggregationInfo> info = futures.stream().map(f -> {
+            try {
+                return f.get();
+            }
+            catch (ExecutionException | InterruptedException e) {
+                throw new RuntimeException(e);
+            }
+        })
+                .map(LinearProbingHashMap::toMap)
+                .flatMap(map -> map.entrySet().stream())
+                .sequential()
+                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, AggregationInfo::update));
+
+        Map<String, ResultRow> measurements = new TreeMap<>(info.entrySet().stream()
                 .collect(toMap(Map.Entry::getKey, e -> resultRow(e.getValue()))));
 
         System.out.println(measurements);
     }
 
+    /**
+     * Stores the required running aggregation information to be able to compute min/max/average at the end
+     */
     private static class AggregationInfo {
         double min = Double.POSITIVE_INFINITY;
         double max = Double.NEGATIVE_INFINITY;
@@ -108,13 +137,14 @@
         int read(byte[] buffer, long offset) throws Exception;
     }
 
+    /**
+     * Uses RandomAccessFile's seek and read APIs to load data into a buffer.
+     */
     private static class RandomAccessBasedProvider implements BufferDataProvider {
         private final String filePath;
-        private final long fileSize;
 
         RandomAccessBasedProvider(String filePath, long fileSize) {
             this.filePath = filePath;
-            this.fileSize = fileSize;
         }
 
         @Override
@@ -133,7 +163,10 @@
         }
     }
 
-    private static class Aggregator implements Runnable {
+    /**
+     * Task that processes one chunk of the file and returns a custom hash map holding that chunk's partial aggregations
+     */
+    private static class Aggregator implements Callable<LinearProbingHashMap> {
         private final long startByte;
         private final BufferDataProvider dataProvider;
 
@@ -143,7 +176,7 @@
         }
 
         @Override
-        public void run() {
+        public LinearProbingHashMap call() {
             try {
                 // offset for the last byte to be processed (excluded)
                 long endByte = Math.min(startByte + CHUNK_SIZE, FILE_SIZE);
@@ -151,25 +184,17 @@
                 long bufferSize = endByte - startByte + ((endByte == FILE_SIZE) ? 0 : PADDING);
                 byte[] buffer = new byte[(int) bufferSize];
                 int bytes = dataProvider.read(buffer, startByte);
-                // Partial aggregation to avoid accessing global concurrent map for every entry
-                Map<String, AggregationInfo> updated = processBuffer(
-                        buffer, startByte == 0, endByte - startByte);
-                // Full aggregation with global map
-                updated.entrySet().forEach(entry -> {
-                    globalMap.compute(entry.getKey(), (k, v) -> {
-                        if (v == null) {
-                            return entry.getValue();
-                        }
-                        return v.update(entry.getValue());
-                    });
-                });
+                // Partial aggregation in a hashmap
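+                // The buffer intentionally extends PADDING bytes past the chunk boundary (except at EOF), so the
+                // line straddling the boundary can be finished here; the next chunk skips ahead to its first '\n'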
+                return processBuffer(buffer, startByte == 0, endByte - startByte);
             }
             catch (Throwable e) {
                 throw new RuntimeException(e);
             }
         }
 
-        private static Map<String, AggregationInfo> processBuffer(byte[] buffer, boolean isFileStart, long nextChunkStart) {
+        private static LinearProbingHashMap processBuffer(byte[] buffer, boolean isFileStart, long nextChunkStart) {
             int start = 0;
             // Move to the next entry after '\n'. Don't do this if we're at the start of
             // the file to avoid missing first entry.
@@ -180,13 +205,15 @@ public class CalculateAverage_phd3 {
                 start += 1;
             }
 
-            // local map for this thread, don't need thread safety
-            Map<String, AggregationInfo> chunkMap = new HashMap<>();
+            LinearProbingHashMap chunkLocalMap = new LinearProbingHashMap();
             while (true) {
                 LineInfo lineInfo = getNextLine(buffer, start);
-                String key = new String(buffer, start, lineInfo.semicolonIndex - start);
+                byte[] keyBytes = new byte[lineInfo.semicolonIndex - start];
+                System.arraycopy(buffer, start, keyBytes, 0, keyBytes.length);
                 double value = parseDouble(buffer, lineInfo.semicolonIndex + 1, lineInfo.nextStart - 1);
-                update(chunkMap, key, value);
+                // Update the aggregated value for the given key with the new line
+                AggregationInfo info = chunkLocalMap.get(keyBytes, lineInfo.keyHash);
+                info.update(value);
 
                 if ((lineInfo.nextStart > nextChunkStart) || (lineInfo.nextStart >= buffer.length)) {
                     // we are already at a point where the next line will be processed in the next chunk,
@@ -196,9 +223,12 @@
                 start = lineInfo.nextStart();
             }
 
-            return chunkMap;
+            return chunkLocalMap;
         }
 
+        /**
+         * Converts bytes to a double value without an intermediate String; faster than Double.parseDouble.
+         */
         private static double parseDouble(byte[] bytes, int offset, int end) {
             boolean negative = (bytes[offset] == '-');
             int current = negative ? offset + 1 : offset;
@@ -216,26 +246,98 @@
             return (preFloat + ((postFloat) / POWERS_OF_10[end - postFloatStart])) * (negative ? -1 : 1);
         }
 
-        private static void update(Map<String, AggregationInfo> state, String key, double value) {
-            AggregationInfo info = state.computeIfAbsent(key, k -> new AggregationInfo());
-            info.update(value);
-        }
-
-        // identifies indexes of the next ';' and '\n', which will be used to get entry key and value from line
+        /**
+         * Identifies the indexes of the next ';' and '\n', which are used to extract the entry key and value from a
+         * line. Also computes the hash value for the key while iterating.
+         */
         private static LineInfo getNextLine(byte[] buffer, int start) {
             // caller guarantees that the access is in bounds, so no index check
+            int hash = 0;
             while (buffer[start] != ';') {
+                hash = hash * 31 + buffer[start];
                 start++;
             }
+            // The following is just to further reduce the probability of collisions
+            hash = hash ^ (hash << 16);
             int semicolonIndex = start;
             // caller guarantees that the access is in bounds, so no index check
             while (buffer[start] != '\n') {
                 start++;
             }
-            return new LineInfo(semicolonIndex, start + 1);
+            return new LineInfo(semicolonIndex, start + 1, hash);
         }
     }
 
-    private record LineInfo(int semicolonIndex, int nextStart) {
+    private record LineInfo(int semicolonIndex, int nextStart, int keyHash) {
     }
+
+    /**
+     * A simple map with a pre-configured, fixed bucket count. Despite the name, every bucket is a linked list
+     * (separate chaining rather than linear probing). With 2^13 buckets and the current hash function, around 4
+     * collisions were observed, which is not too bad. The map is NOT thread safe.
+     */
+    private static class LinearProbingHashMap {
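+        // The bucket count is a power of two, so a bucket index can be derived from a hash with a single bitwise AND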
+        private static final int BUCKET_COUNT = 8192;
+        private final Node[] buckets;
+
+        LinearProbingHashMap() {
+            this.buckets = new Node[BUCKET_COUNT];
+        }
+
+        /**
+         * Given a key, returns the current value of AggregationInfo. If not present, creates a new empty node at the
+         * front of the bucket
+         */
+        public AggregationInfo get(byte[] key, int keyHash) {
+            // find the bucket index through a bitwise AND; works because BUCKET_COUNT is a power of two (2^p)
+            int bucketIndex = keyHash & (BUCKET_COUNT - 1);
+            Node current = buckets[bucketIndex];
+            while (current != null) {
+                if (Arrays.equals(current.entry.key(), key)) {
+                    return current.entry.aggregationInfo();
+                }
+                current = current.next;
+            }
+
+            // Entry does not exist, so add a new node at the front of the linked list
+            AggregationInfo newInfo = new AggregationInfo();
+            KeyValuePair pair = new KeyValuePair(key, keyHash, newInfo);
+            Node newNode = new Node(pair, buckets[bucketIndex]);
+            buckets[bucketIndex] = newNode;
+            return newNode.entry.aggregationInfo();
+        }
+
+        /**
+         * A helper that converts to Java's HashMap, used to build the final aggregation from the partial aggregations
+         */
+        private Map<String, AggregationInfo> toMap() {
+            Map<String, AggregationInfo> map = new HashMap<>();
+            for (Node bucket : buckets) {
+                while (bucket != null) {
+                    map.put(new String(bucket.entry.key, StandardCharsets.UTF_8), bucket.entry.aggregationInfo());
+                    bucket = bucket.next;
+                }
+            }
+            return map;
+        }
+    }
+
+    /**
+     * Linked list node used to implement a bucket of the custom hash map
+     */
+    private static class Node {
+        KeyValuePair entry;
+        Node next;
+
+        public Node(KeyValuePair entry, Node next) {
+            this.entry = entry;
+            this.next = next;
+        }
+    }
+
+    /**
+     * A wrapper that stores the information needed for one measurement's aggregation entry in the hash map
+     */
+    private record KeyValuePair(byte[] key, int keyHash, AggregationInfo aggregationInfo) {
+    }
 }
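
A note on the byte-level temperature parsing in the patch above: since each measurement is a short decimal such as "-12.3", the parser accumulates the digits on either side of the '.' as integers and rescales the fractional part with a precomputed power of ten, avoiding both a String allocation and Double.parseDouble. The following is a minimal standalone sketch of the same idea; the class name and the main harness are illustrative only and not part of the submission, and like the submission it assumes the input always contains a '.':

    import java.nio.charset.StandardCharsets;

    public class ParseSketch {
        // Mirrors POWERS_OF_10 in the submission: 10^0 .. 10^5
        private static final double[] POWERS_OF_10 = { 1, 10, 100, 1000, 10000, 100000 };

        // Parses bytes[offset, end) as a signed decimal such as "-12.3" without allocating a String.
        // Slightly defensive variant: bounds checks are kept, whereas the submission relies on caller guarantees.
        static double parseDouble(byte[] bytes, int offset, int end) {
            boolean negative = bytes[offset] == '-';
            int current = negative ? offset + 1 : offset;
            long integerPart = 0;
            // accumulate the digits before the '.'
            while (current < end && bytes[current] != '.') {
                integerPart = integerPart * 10 + (bytes[current] - '0');
                current++;
            }
            int fractionStart = ++current; // skip the '.'
            long fractionPart = 0;
            // accumulate the digits after the '.'
            while (current < end) {
                fractionPart = fractionPart * 10 + (bytes[current] - '0');
                current++;
            }
            double value = integerPart + fractionPart / POWERS_OF_10[end - fractionStart];
            return negative ? -value : value;
        }

        public static void main(String[] args) {
            byte[] line = "-12.3".getBytes(StandardCharsets.US_ASCII);
            System.out.println(parseDouble(line, 0, line.length)); // prints -12.3
        }
    }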