12s (25%) faster on 4 core i7 (#421)

This commit is contained in:
Dr Ian Preston 2024-01-15 17:58:23 +00:00 committed by GitHub
parent dbdd89a847
commit eaa4050a1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,85 +18,88 @@ package dev.morling.onebrc;
import java.io.*; import java.io.*;
import java.nio.*; import java.nio.*;
import java.nio.channels.*; import java.nio.channels.*;
import java.util.concurrent.*;
import java.util.stream.*; import java.util.stream.*;
import java.util.*; import java.util.*;
/* A simple implementation that memory maps the file, reads chunks in parallel and minimises allocation without any unsafe. /* A simple implementation aiming for readability.
* Features:
* * memory mapped file
* * read chunks in parallel
* * minimise allocation
* * no unsafe
* *
* Timings on 4 core i7-7500U CPU @ 2.70GHz: * Timings on 4 core i7-7500U CPU @ 2.70GHz:
* average_baseline: 4m48s * average_baseline: 4m48s
* ianopolous: 48s * ianopolous: 36s
*/ */
public class CalculateAverage_ianopolous { public class CalculateAverage_ianopolous {
public static final int MAX_LINE_LENGTH = 107; public static final int MAX_LINE_LENGTH = 107;
public static final int MAX_STATIONS = 10000; public static final int MAX_STATIONS = 10_000;
public static void main(String[] args) { public static void main(String[] args) throws Exception {
File input = new File("./measurements.txt"); File input = new File("./measurements.txt");
long filesize = input.length(); long filesize = input.length();
long chunkSize = 256 * 1024 * 1024; // keep chunk size between 256 MB and 1G (1 chunk for files < 256MB)
long chunkSize = Math.min(Math.max(filesize / 32, 256 * 1024 * 1024), 1024 * 1024 * 1024L);
int nChunks = (int) ((filesize + chunkSize - 1) / chunkSize); int nChunks = (int) ((filesize + chunkSize - 1) / chunkSize);
List<HashMap<String, Stat>> allResults = IntStream.range(0, nChunks).mapToObj(i -> { ExecutorService pool = Executors.newVirtualThreadPerTaskExecutor();
HashMap<String, Stat> results = new HashMap(512); List<Future<List<List<Stat>>>> allResults = IntStream.range(0, nChunks)
parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize), results); .mapToObj(i -> pool.submit(() -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize))))
return results; .toList();
}).parallel().toList();
HashMap<String, Stat> result = allResults.getFirst(); TreeMap<String, Stat> merged = allResults.stream()
for (int i = 1; i < allResults.size(); ++i) { .parallel()
for (Map.Entry<String, Stat> entry : allResults.get(i).entrySet()) { .flatMap(f -> {
Stat current = result.putIfAbsent(entry.getKey(), entry.getValue()); try {
if (current != null) { return f.get().stream().filter(Objects::nonNull).flatMap(Collection::stream);
current.merge(entry.getValue());
} }
catch (Exception e) {
return Stream.empty();
} }
})
.collect(Collectors.toMap(s -> s.name(), s -> s, (a, b) -> a.merge(b), TreeMap::new));
System.out.println(merged);
} }
System.out.println(new TreeMap<>(result)); public static boolean matchingStationBytes(int start, int end, MappedByteBuffer buffer, Stat existing) {
}
public record Station(String name, ByteBuffer buf) {
}
public static boolean matchingStationBytes(int start, int end, MappedByteBuffer buffer, Station existing) {
buffer.position(start);
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
if (existing.buf.get(i - start) != buffer.get(i)) if (existing.name[i - start] != buffer.get(i))
return false; return false;
} }
return true; return true;
} }
public static Station parseStation(int start, int end, int hash, MappedByteBuffer buffer, List<List<Station>> stations) { public static Stat parseStation(int start, int end, int hash, MappedByteBuffer buffer, List<List<Stat>> stations) {
int index = Math.floorMod(hash, MAX_STATIONS); int index = Math.floorMod(hash, MAX_STATIONS);
List<Station> matches = stations.get(index); List<Stat> matches = stations.get(index);
if (matches == null) { if (matches == null) {
List<Station> value = new ArrayList<>(); List<Stat> value = new ArrayList<>();
byte[] stationBuffer = new byte[end - start]; byte[] stationBuffer = new byte[end - start];
buffer.position(start); buffer.position(start);
buffer.get(stationBuffer); buffer.get(stationBuffer);
String name = new String(stationBuffer); Stat res = new Stat(stationBuffer);
Station res = new Station(name, ByteBuffer.wrap(stationBuffer));
value.add(res); value.add(res);
stations.set(index, value); stations.set(index, value);
return res; return res;
} }
else { else {
for (int i = 0; i < matches.size(); i++) { for (int i = 0; i < matches.size(); i++) {
Station s = matches.get(i); Stat s = matches.get(i);
if (matchingStationBytes(start, end, buffer, s)) if (matchingStationBytes(start, end, buffer, s))
return s; return s;
} }
byte[] stationBuffer = new byte[end - start]; byte[] stationBuffer = new byte[end - start];
buffer.position(start); buffer.position(start);
buffer.get(stationBuffer); buffer.get(stationBuffer);
Station res = new Station(new String(stationBuffer), ByteBuffer.wrap(stationBuffer)); Stat res = new Stat(stationBuffer);
matches.add(res); matches.add(res);
return res; return res;
} }
} }
public static void parseStats(long startByte, long endByte, Map<String, Stat> results) { public static List<List<Stat>> parseStats(long startByte, long endByte) {
try { try {
RandomAccessFile file = new RandomAccessFile("./measurements.txt", "r"); RandomAccessFile file = new RandomAccessFile("./measurements.txt", "r");
long maxEnd = Math.min(file.length(), endByte + MAX_LINE_LENGTH); long maxEnd = Math.min(file.length(), endByte + MAX_LINE_LENGTH);
@ -117,30 +120,22 @@ public class CalculateAverage_ianopolous {
} }
} }
List<List<Station>> stations = new ArrayList<>(MAX_STATIONS); List<List<Stat>> stations = new ArrayList<>(MAX_STATIONS);
for (int i = 0; i < MAX_STATIONS; i++) for (int i = 0; i < MAX_STATIONS; i++)
stations.add(null); stations.add(null);
int lineStart = done; int lineStart = done;
int lineSplit = 0; int lineSplit = 0;
long temperature = 0; short temperature = 0;
int hash = 1; int hash = 1;
boolean negative = false; boolean negative = false;
while (done < maxDone) { while (done < maxDone) {
Station station = null; Stat station = null;
for (int i = done; i < done + MAX_LINE_LENGTH && i < maxEnd; i++) { for (int i = done; i < done + MAX_LINE_LENGTH && i < maxEnd; i++) {
byte b = buffer.get(i); byte b = buffer.get(i);
if (b == '\n') { if (b == '\n') {
done = i + 1; done = i + 1;
Stat res = results.get(station.name); temperature = negative ? (short) -temperature : temperature;
temperature = negative ? -temperature : temperature; station.add(temperature);
if (res != null) {
res.add(temperature);
}
else {
res = new Stat();
res.add(temperature);
results.put(station.name, res);
}
lineStart = done; lineStart = done;
station = null; station = null;
hash = 1; hash = 1;
@ -152,27 +147,35 @@ public class CalculateAverage_ianopolous {
temperature = 0; temperature = 0;
negative = false; negative = false;
} }
else if (b == '-' && station != null) { else if (station == null) {
negative = true;
}
else if (b != '.' && station != null) {
temperature = temperature * 10 + (b - 0x30);
}
else {
hash = 31 * hash + b; hash = 31 * hash + b;
} }
else if (b == '-') {
negative = true;
}
else if (b != '.') {
temperature = (short) (temperature * 10 + (b - 0x30));
} }
} }
} }
return stations;
}
catch (IOException e) { catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
public static class Stat { public static class Stat {
long min = Long.MAX_VALUE, max = Long.MIN_VALUE, total = 0, count = 0; final byte[] name;
int count = 0;
short min = Short.MAX_VALUE, max = Short.MIN_VALUE;
long total = 0;
public void add(long value) { public Stat(byte[] name) {
this.name = name;
}
public void add(short value) {
if (value < min) if (value < min)
min = value; min = value;
if (value > max) if (value > max)
@ -181,19 +184,24 @@ public class CalculateAverage_ianopolous {
count++; count++;
} }
public void merge(Stat value) { public Stat merge(Stat value) {
if (value.min < min) if (value.min < min)
min = value.min; min = value.min;
if (value.max > max) if (value.max > max)
max = value.max; max = value.max;
total += value.total; total += value.total;
count += value.count; count += value.count;
return this;
} }
private static double round(double value) { private static double round(double value) {
return Math.round(value) / 10.0; return Math.round(value) / 10.0;
} }
public String name() {
return new String(name);
}
public String toString() { public String toString() {
return round((double) min) + "/" + round(((double) total) / count) + "/" + round((double) max); return round((double) min) + "/" + round(((double) total) / count) + "/" + round((double) max);
} }