From eaa4050a1b479aaeac3ac2ea1caf8b5da1bbd42d Mon Sep 17 00:00:00 2001 From: Dr Ian Preston Date: Mon, 15 Jan 2024 17:58:23 +0000 Subject: [PATCH] 12s (25%) faster on 4 core i7 (#421) --- .../onebrc/CalculateAverage_ianopolous.java | 120 ++++++++++-------- 1 file changed, 64 insertions(+), 56 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolous.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolous.java index 834de74..4d82d88 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolous.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolous.java @@ -18,85 +18,88 @@ package dev.morling.onebrc; import java.io.*; import java.nio.*; import java.nio.channels.*; +import java.util.concurrent.*; import java.util.stream.*; import java.util.*; -/* A simple implementation that memory maps the file, reads chunks in parallel and minimises allocation without any unsafe. +/* A simple implementation aiming for readability. + * Features: + * * memory mapped file + * * read chunks in parallel + * * minimise allocation + * * no unsafe * * Timings on 4 core i7-7500U CPU @ 2.70GHz: * average_baseline: 4m48s - * ianopolous: 48s + * ianopolous: 36s */ public class CalculateAverage_ianopolous { public static final int MAX_LINE_LENGTH = 107; - public static final int MAX_STATIONS = 10000; + public static final int MAX_STATIONS = 10_000; - public static void main(String[] args) { + public static void main(String[] args) throws Exception { File input = new File("./measurements.txt"); long filesize = input.length(); - long chunkSize = 256 * 1024 * 1024; + // keep chunk size between 256 MB and 1G (1 chunk for files < 256MB) + long chunkSize = Math.min(Math.max(filesize / 32, 256 * 1024 * 1024), 1024 * 1024 * 1024L); int nChunks = (int) ((filesize + chunkSize - 1) / chunkSize); - List> allResults = IntStream.range(0, nChunks).mapToObj(i -> { - HashMap results = new HashMap(512); - parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize), results); - return results; - }).parallel().toList(); - HashMap result = allResults.getFirst(); - for (int i = 1; i < allResults.size(); ++i) { - for (Map.Entry entry : allResults.get(i).entrySet()) { - Stat current = result.putIfAbsent(entry.getKey(), entry.getValue()); - if (current != null) { - current.merge(entry.getValue()); - } - } - } + ExecutorService pool = Executors.newVirtualThreadPerTaskExecutor(); + List>>> allResults = IntStream.range(0, nChunks) + .mapToObj(i -> pool.submit(() -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize)))) + .toList(); - System.out.println(new TreeMap<>(result)); + TreeMap merged = allResults.stream() + .parallel() + .flatMap(f -> { + try { + return f.get().stream().filter(Objects::nonNull).flatMap(Collection::stream); + } + catch (Exception e) { + return Stream.empty(); + } + }) + .collect(Collectors.toMap(s -> s.name(), s -> s, (a, b) -> a.merge(b), TreeMap::new)); + System.out.println(merged); } - public record Station(String name, ByteBuffer buf) { - } - - public static boolean matchingStationBytes(int start, int end, MappedByteBuffer buffer, Station existing) { - buffer.position(start); + public static boolean matchingStationBytes(int start, int end, MappedByteBuffer buffer, Stat existing) { for (int i = start; i < end; i++) { - if (existing.buf.get(i - start) != buffer.get(i)) + if (existing.name[i - start] != buffer.get(i)) return false; } return true; } - public static Station parseStation(int start, int end, int hash, MappedByteBuffer buffer, List> stations) { + public static Stat parseStation(int start, int end, int hash, MappedByteBuffer buffer, List> stations) { int index = Math.floorMod(hash, MAX_STATIONS); - List matches = stations.get(index); + List matches = stations.get(index); if (matches == null) { - List value = new ArrayList<>(); + List value = new ArrayList<>(); byte[] stationBuffer = new byte[end - start]; buffer.position(start); buffer.get(stationBuffer); - String name = new String(stationBuffer); - Station res = new Station(name, ByteBuffer.wrap(stationBuffer)); + Stat res = new Stat(stationBuffer); value.add(res); stations.set(index, value); return res; } else { for (int i = 0; i < matches.size(); i++) { - Station s = matches.get(i); + Stat s = matches.get(i); if (matchingStationBytes(start, end, buffer, s)) return s; } byte[] stationBuffer = new byte[end - start]; buffer.position(start); buffer.get(stationBuffer); - Station res = new Station(new String(stationBuffer), ByteBuffer.wrap(stationBuffer)); + Stat res = new Stat(stationBuffer); matches.add(res); return res; } } - public static void parseStats(long startByte, long endByte, Map results) { + public static List> parseStats(long startByte, long endByte) { try { RandomAccessFile file = new RandomAccessFile("./measurements.txt", "r"); long maxEnd = Math.min(file.length(), endByte + MAX_LINE_LENGTH); @@ -117,30 +120,22 @@ public class CalculateAverage_ianopolous { } } - List> stations = new ArrayList<>(MAX_STATIONS); + List> stations = new ArrayList<>(MAX_STATIONS); for (int i = 0; i < MAX_STATIONS; i++) stations.add(null); int lineStart = done; int lineSplit = 0; - long temperature = 0; + short temperature = 0; int hash = 1; boolean negative = false; while (done < maxDone) { - Station station = null; + Stat station = null; for (int i = done; i < done + MAX_LINE_LENGTH && i < maxEnd; i++) { byte b = buffer.get(i); if (b == '\n') { done = i + 1; - Stat res = results.get(station.name); - temperature = negative ? -temperature : temperature; - if (res != null) { - res.add(temperature); - } - else { - res = new Stat(); - res.add(temperature); - results.put(station.name, res); - } + temperature = negative ? (short) -temperature : temperature; + station.add(temperature); lineStart = done; station = null; hash = 1; @@ -152,17 +147,18 @@ public class CalculateAverage_ianopolous { temperature = 0; negative = false; } - else if (b == '-' && station != null) { + else if (station == null) { + hash = 31 * hash + b; + } + else if (b == '-') { negative = true; } - else if (b != '.' && station != null) { - temperature = temperature * 10 + (b - 0x30); - } - else { - hash = 31 * hash + b; + else if (b != '.') { + temperature = (short) (temperature * 10 + (b - 0x30)); } } } + return stations; } catch (IOException e) { throw new RuntimeException(e); @@ -170,9 +166,16 @@ public class CalculateAverage_ianopolous { } public static class Stat { - long min = Long.MAX_VALUE, max = Long.MIN_VALUE, total = 0, count = 0; + final byte[] name; + int count = 0; + short min = Short.MAX_VALUE, max = Short.MIN_VALUE; + long total = 0; - public void add(long value) { + public Stat(byte[] name) { + this.name = name; + } + + public void add(short value) { if (value < min) min = value; if (value > max) @@ -181,19 +184,24 @@ public class CalculateAverage_ianopolous { count++; } - public void merge(Stat value) { + public Stat merge(Stat value) { if (value.min < min) min = value.min; if (value.max > max) max = value.max; total += value.total; count += value.count; + return this; } private static double round(double value) { return Math.round(value) / 10.0; } + public String name() { + return new String(name); + } + public String toString() { return round((double) min) + "/" + round(((double) total) / count) + "/" + round((double) max); }