Add improvements (#412)

- custom hashmap
- avoid string creation
- use graal

parent cd0e20b304, commit 6c7012a43e

Files changed: prepare_phd3.sh (new executable file, +19 lines), CalculateAverage_phd3.java
prepare_phd3.sh (new executable file)
@@ -0,0 +1,19 @@
+#!/bin/bash
+#
+# Copyright 2023 The original authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+source "$HOME/.sdkman/bin/sdkman-init.sh"
+sdk use java 21.0.1-graal 1>&2
CalculateAverage_phd3.java (modified)
@@ -15,18 +15,24 @@
  */
 package dev.morling.onebrc;
 
-import static java.nio.charset.StandardCharsets.*;
 import static java.util.stream.Collectors.*;
 
 import java.io.File;
 import java.io.RandomAccessFile;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.TreeMap;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 public class CalculateAverage_phd3 {
@@ -34,12 +40,16 @@ public class CalculateAverage_phd3 {
     private static final int NUM_THREADS = Runtime.getRuntime().availableProcessors() * 2;
     private static final String FILE = "./measurements.txt";
     private static final long FILE_SIZE = new File(FILE).length();
+    // A chunk is a unit for processing; the file will be divided in chunks of the following size
     private static final int CHUNK_SIZE = 65536 * 1024;
+    // Read a little more data into the buffer to finish processing the current line
    private static final int PADDING = 512;
+    // Minor: precompute powers to avoid recalculating while parsing doubles (temperatures)
     private static final double[] POWERS_OF_10 = IntStream.range(0, 6).mapToDouble(x -> Math.pow(10.0, x)).toArray();
 
-    private static final Map<String, AggregationInfo> globalMap = new ConcurrentHashMap<>();
-
+    /**
+     * A utility to print aggregated information in the desired format
+     */
     private record ResultRow(double min, double mean, double max) {
 
         public String toString() {
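For a sense of scale, here is a small standalone sketch of how many chunk tasks these constants produce. The file size below is assumed (roughly what a 1BRC measurements.txt weighs) and is not taken from this commit; the formula matches the numChunks computation in main further down.

public class ChunkMathSketch {
    public static void main(String[] args) {
        // Assumed input size (~13 GB is typical for the 1BRC file); not from the commit.
        long fileLength = 13_795_000_000L;
        int chunkSize = 65536 * 1024; // 64 MiB, matching CHUNK_SIZE above
        // Ceiling division, as in main: the last partial chunk still gets a task.
        int numChunks = (int) Math.ceil(fileLength * 1.0 / chunkSize);
        System.out.println(numChunks); // ~206 chunk tasks submitted to the pool
    }
}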
@@ -52,7 +62,7 @@ public class CalculateAverage_phd3 {
     };
 
     public static ResultRow resultRow(AggregationInfo aggregationInfo) {
-        return new ResultRow(aggregationInfo.min, aggregationInfo.sum / aggregationInfo.count, aggregationInfo.max);
+        return new ResultRow(aggregationInfo.min, (Math.round(aggregationInfo.sum * 10.0) / 10.0) / (aggregationInfo.count), aggregationInfo.max);
     }
 
     public static void main(String[] args) throws Exception {
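The new mean first snaps the accumulated sum to one decimal via Math.round(sum * 10.0) / 10.0 and only then divides by the count. A minimal sketch with made-up readings showing the difference against dividing the raw sum:

public class RoundingSketch {
    public static void main(String[] args) {
        // Made-up running aggregate of three readings; doubles accumulate tiny errors.
        double sum = 12.3 + 45.1 + -7.8;
        long count = 3;
        // Naive mean may expose the accumulated error in the printed digits.
        System.out.println(sum / count);
        // The commit instead snaps the sum to one decimal before dividing.
        System.out.println((Math.round(sum * 10.0) / 10.0) / count);
    }
}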
@@ -60,19 +70,37 @@ public class CalculateAverage_phd3 {
         int numChunks = (int) Math.ceil(fileLength * 1.0 / CHUNK_SIZE);
         ExecutorService executorService = Executors.newFixedThreadPool(NUM_THREADS);
         BufferDataProvider provider = new RandomAccessBasedProvider(FILE, FILE_SIZE);
+        List<Future<LinearProbingHashMap>> futures = new ArrayList<>();
+        // Process chunks in parallel
         for (int chunkIndex = 0; chunkIndex < numChunks; chunkIndex++) {
-            executorService.submit(new Aggregator(chunkIndex, provider));
+            futures.add(executorService.submit(new Aggregator(chunkIndex, provider)));
         }
 
         executorService.shutdown();
         executorService.awaitTermination(10, TimeUnit.MINUTES);
 
-        Map<String, ResultRow> measurements = new TreeMap<>(globalMap.entrySet().stream()
+        Map<String, AggregationInfo> info = futures.stream().map(f -> {
+            try {
+                return f.get();
+            }
+            catch (ExecutionException | InterruptedException e) {
+                throw new RuntimeException(e);
+            }
+        })
+                .map(LinearProbingHashMap::toMap)
+                .flatMap(map -> map.entrySet().stream())
+                .sequential()
+                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, AggregationInfo::update));
+
+        Map<String, ResultRow> measurements = new TreeMap<>(info.entrySet().stream()
                 .collect(toMap(Map.Entry::getKey, e -> resultRow(e.getValue()))));
 
         System.out.println(measurements);
     }
 
+    /**
+     * Stores required running aggregation information to be able to compute min/max/average at the end
+     */
     private static class AggregationInfo {
         double min = Double.POSITIVE_INFINITY;
         double max = Double.NEGATIVE_INFINITY;
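Each Aggregator now returns its partial result through a Future, and the three-argument Collectors.toMap overload merges colliding keys across chunks (via AggregationInfo::update in the commit). A stripped-down sketch of the same pattern; the string-count maps and Integer::sum here are invented stand-ins for the real per-station aggregates:

import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;

public class MergeSketch {
    public static void main(String[] args) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        // Two invented partial results that share the key "Hamburg".
        Callable<Map<String, Integer>> chunk1 = () -> Map.of("Hamburg", 1, "Oslo", 2);
        Callable<Map<String, Integer>> chunk2 = () -> Map.of("Hamburg", 3);
        List<Future<Map<String, Integer>>> futures = List.of(pool.submit(chunk1), pool.submit(chunk2));

        Map<String, Integer> merged = futures.stream().map(f -> {
            try {
                return f.get(); // blocks until that chunk's partial map is ready
            }
            catch (ExecutionException | InterruptedException e) {
                throw new RuntimeException(e);
            }
        })
                .flatMap(map -> map.entrySet().stream())
                // The third argument resolves key collisions across chunks,
                // playing the role of AggregationInfo::update in the commit.
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, Integer::sum));

        System.out.println(merged); // e.g. {Oslo=2, Hamburg=4}
        pool.shutdown();
    }
}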
@@ -108,13 +136,14 @@ public class CalculateAverage_phd3 {
         int read(byte[] buffer, long offset) throws Exception;
     }
 
+    /**
+     * Uses RandomAccessFile seek and read APIs to load data into a buffer.
+     */
     private static class RandomAccessBasedProvider implements BufferDataProvider {
         private final String filePath;
-        private final long fileSize;
 
         RandomAccessBasedProvider(String filePath, long fileSize) {
             this.filePath = filePath;
-            this.fileSize = fileSize;
         }
 
         @Override
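The read implementation itself is outside this hunk. As a sketch of the seek-and-read pattern the javadoc above describes (the method shape here is assumed, not copied from the commit):

import java.io.RandomAccessFile;

// Sketch of the seek-and-read pattern described in the javadoc; the commit's
// actual read(...) body is not shown in this hunk, so treat the shape as assumed.
class SeekReadSketch {
    static int read(String filePath, byte[] buffer, long offset) throws Exception {
        try (RandomAccessFile file = new RandomAccessFile(filePath, "r")) {
            file.seek(offset);        // jump straight to the chunk's first byte
            return file.read(buffer); // fill the buffer (chunk plus padding); returns bytes read
        }
    }
}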
@@ -133,7 +162,10 @@ public class CalculateAverage_phd3 {
         }
     }
 
-    private static class Aggregator implements Runnable {
+    /**
+     * Task to process a chunk of the file and return a custom linear probing hashmap for performance
+     */
+    private static class Aggregator implements Callable<LinearProbingHashMap> {
         private final long startByte;
         private final BufferDataProvider dataProvider;
 
@@ -143,7 +175,7 @@ public class CalculateAverage_phd3 {
         }
 
         @Override
-        public void run() {
+        public LinearProbingHashMap call() {
             try {
                 // offset for the last byte to be processed (excluded)
                 long endByte = Math.min(startByte + CHUNK_SIZE, FILE_SIZE);
@@ -151,25 +183,15 @@ public class CalculateAverage_phd3 {
                 long bufferSize = endByte - startByte + ((endByte == FILE_SIZE) ? 0 : PADDING);
                 byte[] buffer = new byte[(int) bufferSize];
                 int bytes = dataProvider.read(buffer, startByte);
-                // Partial aggregation to avoid accessing global concurrent map for every entry
-                Map<String, AggregationInfo> updated = processBuffer(
-                        buffer, startByte == 0, endByte - startByte);
-                // Full aggregation with global map
-                updated.entrySet().forEach(entry -> {
-                    globalMap.compute(entry.getKey(), (k, v) -> {
-                        if (v == null) {
-                            return entry.getValue();
-                        }
-                        return v.update(entry.getValue());
-                    });
-                });
+                // Partial aggregation in a hashmap
+                return processBuffer(buffer, startByte == 0, endByte - startByte);
             }
             catch (Throwable e) {
                 throw new RuntimeException(e);
             }
         }
 
-    private static Map<String, AggregationInfo> processBuffer(byte[] buffer, boolean isFileStart, long nextChunkStart) {
+    private static LinearProbingHashMap processBuffer(byte[] buffer, boolean isFileStart, long nextChunkStart) {
         int start = 0;
         // Move to the next entry after '\n'. Don't do this if we're at the start of
         // the file to avoid missing first entry.
@@ -180,13 +202,15 @@ public class CalculateAverage_phd3 {
             start += 1;
         }
 
-        // local map for this thread, don't need thread safety
-        Map<String, AggregationInfo> chunkMap = new HashMap<>();
+        LinearProbingHashMap chunkLocalMap = new LinearProbingHashMap();
         while (true) {
             LineInfo lineInfo = getNextLine(buffer, start);
-            String key = new String(buffer, start, lineInfo.semicolonIndex - start);
+            byte[] keyBytes = new byte[lineInfo.semicolonIndex - start];
+            System.arraycopy(buffer, start, keyBytes, 0, keyBytes.length);
             double value = parseDouble(buffer, lineInfo.semicolonIndex + 1, lineInfo.nextStart - 1);
-            update(chunkMap, key, value);
+            // Update aggregated value for the given key with the new line
+            AggregationInfo info = chunkLocalMap.get(keyBytes, lineInfo.keyHash);
+            info.update(value);
 
             if ((lineInfo.nextStart > nextChunkStart) || (lineInfo.nextStart >= buffer.length)) {
                 // we are already at a point where the next line will be processed in the next chunk,
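This hunk is the "avoid string creation" part of the commit: the station name stays as raw bytes on the hot path, and equality is later checked byte-wise in the custom map. A self-contained sketch of the same extraction, with an invented one-line buffer:

import java.nio.charset.StandardCharsets;
import java.util.Arrays;

public class ByteKeySketch {
    public static void main(String[] args) {
        byte[] buffer = "Hamburg;12.3\n".getBytes(StandardCharsets.UTF_8);
        int start = 0, semicolonIndex = 7; // positions getNextLine would report
        // Copy only the key bytes; no String is allocated per line.
        byte[] keyBytes = new byte[semicolonIndex - start];
        System.arraycopy(buffer, start, keyBytes, 0, keyBytes.length);
        // Byte-wise equality is what the custom map uses to match keys.
        System.out.println(Arrays.equals(keyBytes, "Hamburg".getBytes(StandardCharsets.UTF_8))); // true
    }
}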
@@ -196,9 +220,12 @@ public class CalculateAverage_phd3 {
 
             start = lineInfo.nextStart();
         }
-        return chunkMap;
+        return chunkLocalMap;
     }
 
+    /**
+     * Converts bytes to double value without intermediate string conversion, faster than Double.parseDouble.
+     */
     private static double parseDouble(byte[] bytes, int offset, int end) {
         boolean negative = (bytes[offset] == '-');
         int current = negative ? offset + 1 : offset;
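Only the javadoc and the sign handling of parseDouble are visible in this hunk. The sketch below reconstructs the likely shape of the rest from the pieces that do appear (POWERS_OF_10, postFloatStart, the final return line in the next hunk); it assumes exactly one fractional digit group, as in 1BRC data, and is not the commit's exact body:

import java.util.stream.IntStream;

public class ParseDoubleSketch {
    private static final double[] POWERS_OF_10 = IntStream.range(0, 6).mapToDouble(x -> Math.pow(10.0, x)).toArray();

    // Assumed reconstruction of the byte-wise parse described in the javadoc above.
    static double parseDouble(byte[] bytes, int offset, int end) {
        boolean negative = (bytes[offset] == '-');
        int current = negative ? offset + 1 : offset;
        long preFloat = 0;
        while (current < end && bytes[current] != '.') {
            preFloat = preFloat * 10 + (bytes[current++] - '0'); // digits before the dot
        }
        int postFloatStart = ++current; // skip the '.'; assumes a dot is always present
        long postFloat = 0;
        while (current < end) {
            postFloat = postFloat * 10 + (bytes[current++] - '0'); // digits after the dot
        }
        return (preFloat + ((postFloat) / POWERS_OF_10[end - postFloatStart])) * (negative ? -1 : 1);
    }

    public static void main(String[] args) {
        byte[] line = "-12.3".getBytes();
        System.out.println(parseDouble(line, 0, line.length)); // -12.3
    }
}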
@@ -216,26 +243,97 @@ public class CalculateAverage_phd3 {
         return (preFloat + ((postFloat) / POWERS_OF_10[end - postFloatStart])) * (negative ? -1 : 1);
     }
 
-    private static void update(Map<String, AggregationInfo> state, String key, double value) {
-        AggregationInfo info = state.computeIfAbsent(key, k -> new AggregationInfo());
-        info.update(value);
-    }
-
-    // identifies indexes of the next ';' and '\n', which will be used to get entry key and value from line
+    /**
+     * Identifies indexes of the next ';' and '\n', which will be used to get entry key and value from line. Also
+     * computes the hash value for the key while iterating.
+     */
     private static LineInfo getNextLine(byte[] buffer, int start) {
         // caller guarantees that the access is in bounds, so no index check
+        int hash = 0;
         while (buffer[start] != ';') {
             start++;
+            hash = hash * 31 + buffer[start];
         }
+        // The following is just to further reduce the probability of collisions
+        hash = hash ^ (hash << 16);
         int semicolonIndex = start;
         // caller guarantees that the access is in bounds, so no index check
         while (buffer[start] != '\n') {
             start++;
         }
-        return new LineInfo(semicolonIndex, start + 1);
+        return new LineInfo(semicolonIndex, start + 1, hash);
     }
 
-    private record LineInfo(int semicolonIndex, int nextStart) {
+    private record LineInfo(int semicolonIndex, int nextStart, int keyHash) {
+    }
+
+    /**
+     * A simple map with pre-configured fixed bucket count. With 2^13 buckets and current hash function, seeing 4
+     * collisions which is not too bad. Every bucket is implemented with a linked list. The map is NOT thread safe.
+     */
+    private static class LinearProbingHashMap {
+        private final static int BUCKET_COUNT = 8191;
+        private final Node[] buckets;
+
+        LinearProbingHashMap() {
+            this.buckets = new Node[BUCKET_COUNT];
+        }
+
+        /**
+         * Given a key, returns the current value of AggregationInfo. If not present, creates a new empty node at the
+         * front of the bucket
+         */
+        public AggregationInfo get(byte[] key, int keyHash) {
+            // find bucket index through bitwise AND, works for bucketCount = (2^p - 1)
+            int bucketIndex = BUCKET_COUNT & keyHash;
+            Node current = buckets[bucketIndex];
+            while (current != null) {
+                if (Arrays.equals(current.entry.key(), key)) {
+                    return current.entry.aggregationInfo();
+                }
+                current = current.next;
+            }
+
+            // Entry does not exist, so add a new node in the linked list
+            AggregationInfo newInfo = new AggregationInfo();
+            KeyValuePair pair = new KeyValuePair(key, keyHash, newInfo);
+            Node newNode = new Node(pair, buckets[bucketIndex]);
+            buckets[bucketIndex] = newNode;
+            return newNode.entry.aggregationInfo();
+        }
+
+        /**
+         * A helper to convert to Java's hash map to build the final aggregation after partial aggregations
+         */
+        private Map<String, AggregationInfo> toMap() {
+            Map<String, AggregationInfo> map = new HashMap<>();
+            for (Node bucket : buckets) {
+                while (bucket != null) {
+                    map.put(new String(bucket.entry.key, StandardCharsets.UTF_8), bucket.entry.aggregationInfo());
+                    bucket = bucket.next;
+                }
+            }
+            return map;
+        }
+    }
+
+    /**
+     * Linked List node to implement a bucket of custom hash map
+     */
+    private static class Node {
+        KeyValuePair entry;
+        Node next;
+
+        public Node(KeyValuePair entry, Node next) {
+            this.entry = entry;
+            this.next = next;
+        }
+    }
+
+    /**
+     * A wrapper class to store the information needed for storing a measurement in the hashmap
+     */
+    private record KeyValuePair(byte[] key, int keyHash, AggregationInfo aggregationInfo) {
     }
 }
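A note on the bucket-index computation above: masking with an all-ones value keeps the low 13 bits of the hash and stays non-negative even for negative hashes, so no Math.abs is needed. However, since BUCKET_COUNT & keyHash can evaluate to 8191 itself, the index can in principle equal the length of an 8191-element array; the conventional form, shown in this sketch, sizes the table at a power of two and masks with size - 1. Treat the discrepancy as a possible off-by-one worth verifying, not a confirmed bug:

public class BucketMaskSketch {
    public static void main(String[] args) {
        int size = 8192;     // power of two
        int mask = size - 1; // 8191, all-ones in the low 13 bits
        int[] sampleHashes = { 123456789, -987654321, 8191 }; // invented values, one negative
        for (int hash : sampleHashes) {
            // & with a non-negative mask always lands in [0, size - 1]
            System.out.println(hash + " -> bucket " + (hash & mask));
        }
    }
}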