From 6aa63e1bd5e2d580324b8ddd58b69d11761b2bf3 Mon Sep 17 00:00:00 2001 From: Nick Palmer Date: Wed, 3 Jan 2024 22:18:40 +0000 Subject: [PATCH] Attempt nicer threading via streams and spliterators --- .../onebrc/CalculateAverage_palmr.java | 248 ++++++++---------- 1 file changed, 109 insertions(+), 139 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_palmr.java b/src/main/java/dev/morling/onebrc/CalculateAverage_palmr.java index bb57d33..c687031 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_palmr.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_palmr.java @@ -21,91 +21,68 @@ import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; public class CalculateAverage_palmr { - private static final String FILE = "./measurements.txt"; - public static final int CHUNK_SIZE = 1024 * 1024 * 10; // Trial and error showed ~10MB to be a good size on our machine - public static final int LITTLE_CHUNK_SIZE = 128; // Enough bytes to cover a station name and measurement value :fingers-crossed: - public static final int STATION_NAME_BUFFER_SIZE = 50; - public static final int THREAD_COUNT = Math.min(8, Runtime.getRuntime().availableProcessors()); + private static final int CHUNK_SIZE = 1024 * 1024 * 10; // Trial and error showed ~10MB to be a good size on our machine + private static final int STATION_NAME_BUFFER_SIZE = 50; + private static final int THREAD_COUNT = Math.min(8, Runtime.getRuntime().availableProcessors()); + private static final char SEPARATOR_CHAR = ';'; + private static final char NEWLINE_CHAR = '\n'; + private static final char MINUS_CHAR = '-'; + private static final char DECIMAL_POINT_CHAR = '.'; public static void main(String[] args) throws IOException { @SuppressWarnings("resource") // It's faster to leak the file than be well-behaved - RandomAccessFile file = new 
RandomAccessFile(FILE, "r"); - FileChannel channel = file.getChannel(); - long fileSize = channel.size(); + final var file = new RandomAccessFile(FILE, "r"); + final var channel = file.getChannel(); - long threadChunk = fileSize / THREAD_COUNT; - - Thread[] threads = new Thread[THREAD_COUNT]; - ByteArrayKeyedMap[] results = new ByteArrayKeyedMap[THREAD_COUNT]; - for (int i = 0; i < THREAD_COUNT; i++) { - final int j = i; - long startPoint = j * threadChunk; - long endPoint = startPoint + threadChunk; - Thread thread = new Thread(() -> { - try { - results[j] = readAndParse(channel, startPoint, endPoint, fileSize); - } - catch (Throwable t) { - System.err.println("It's broken :("); - // noinspection CallToPrintStackTrace - t.printStackTrace(); - } - }); - threads[i] = thread; - thread.start(); - } - - final Map finalAggregator = new TreeMap<>(); - - for (int i = 0; i < THREAD_COUNT; i++) { - try { - threads[i].join(); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - - results[i].getAsUnorderedList().forEach(v -> { - String stationName = new String(v.stationNameBytes, StandardCharsets.UTF_8); - finalAggregator.compute(stationName, (_, x) -> { - if (x == null) { - return v; - } - else { - x.count += v.count; - x.min = Math.min(x.min, v.min); - x.max = Math.max(x.max, v.max); - x.sum += v.sum; - return x; - } - }); - }); - } - System.out.println(finalAggregator); + final TreeMap results = StreamSupport.stream(ThreadChunk.chunk(file, THREAD_COUNT), true) + .map(chunk -> parseChunk(chunk, channel)) + .flatMap(bakm -> bakm.getAsUnorderedList().stream()) + .collect(Collectors.toMap(m -> new String(m.stationNameBytes, StandardCharsets.UTF_8), m -> m, MeasurementAggregator::merge, TreeMap::new)); + System.out.println(results); } - private static ByteArrayKeyedMap readAndParse(final FileChannel channel, - final long startPoint, - final long endPoint, - final long fileSize) { - final State state = new State(); + private record ThreadChunk(long 
startPoint, long endPoint, long size) { + public static Spliterator chunk(final RandomAccessFile file, final int chunkCount) throws IOException { + final var fileSize = file.length(); + final var idealChunkSize = fileSize / THREAD_COUNT; + final var chunks = new CalculateAverage_palmr.ThreadChunk[chunkCount]; - boolean skipFirstEntry = startPoint != 0; + var startPoint = 0L; + for (int i = 0; i < chunkCount; i++) { + var endPoint = Math.min(startPoint + idealChunkSize, fileSize); + file.seek(endPoint); + while (endPoint < fileSize && file.readByte() != NEWLINE_CHAR) { + endPoint++; + } + final var actualSize = endPoint - startPoint; + chunks[i] = new CalculateAverage_palmr.ThreadChunk(startPoint, endPoint, actualSize); + startPoint += actualSize; + } - long offset = startPoint; - while (offset < endPoint) { - parseData(channel, state, offset, Math.min(CHUNK_SIZE, fileSize - offset), false, skipFirstEntry); - skipFirstEntry = false; - offset += CHUNK_SIZE; + return Spliterators.spliterator(chunks, + Spliterator.ORDERED | + Spliterator.DISTINCT | + Spliterator.SORTED | + Spliterator.NONNULL | + Spliterator.IMMUTABLE | + Spliterator.CONCURRENT + ); } + } - if (offset < fileSize) { - // Make sure we finish reading any partially read entry by going a little in to the next chunk, stopping at the first newline - parseData(channel, state, offset, Math.min(LITTLE_CHUNK_SIZE, fileSize - offset), true, false); + private static ByteArrayKeyedMap parseChunk(ThreadChunk chunk, FileChannel channel) { + final var state = new State(); + + var offset = chunk.startPoint; + while (offset < chunk.endPoint) { + parseData(channel, state, offset, Math.min(CHUNK_SIZE, chunk.endPoint - offset)); + offset += CHUNK_SIZE; } return state.aggregators; @@ -114,69 +91,48 @@ public class CalculateAverage_palmr { private static void parseData(final FileChannel channel, final State state, final long offset, - final long bufferSize, - final boolean stopAtNewline, - final boolean skipFirstEntry) { - 
ByteBuffer byteBuffer; + final long bufferSize) { + final ByteBuffer byteBuffer; try { byteBuffer = channel.map(FileChannel.MapMode.READ_ONLY, offset, bufferSize); - } - catch (IOException e) { + + while (byteBuffer.hasRemaining()) { + final var currentChar = byteBuffer.get(); + + if (currentChar == SEPARATOR_CHAR) { + state.parsingValue = true; + } else if (currentChar == NEWLINE_CHAR) { + if (state.stationPointerEnd != 0) { + final var value = state.measurementValue * state.exponent; + + MeasurementAggregator aggregator = state.aggregators.computeIfAbsent(state.stationBuffer, state.stationPointerEnd, state.signedHashCode); + aggregator.count++; + aggregator.min = Math.min(aggregator.min, value); + aggregator.max = Math.max(aggregator.max, value); + aggregator.sum += value; + } + + // reset + state.reset(); + } else { + if (!state.parsingValue) { + state.stationBuffer[state.stationPointerEnd++] = currentChar; + state.signedHashCode = 31 * state.signedHashCode + (currentChar & 0xff); + } else { + if (currentChar == MINUS_CHAR) { + state.exponent = -0.1; + } else if (currentChar != DECIMAL_POINT_CHAR) { + state.measurementValue = state.measurementValue * 10 + (currentChar - '0'); + } + } + } + } + } catch (IOException e) { throw new RuntimeException(e); } - - boolean isSkippingToFirstCleanEntry = skipFirstEntry; - - while (byteBuffer.hasRemaining()) { - byte currentChar = byteBuffer.get(); - - if (isSkippingToFirstCleanEntry) { - if (currentChar == '\n') { - isSkippingToFirstCleanEntry = false; - } - - continue; - } - - if (currentChar == ';') { - state.parsingValue = true; - } - else if (currentChar == '\n') { - if (state.stationPointerEnd != 0) { - double value = state.measurementValue * state.exponent; - - MeasurementAggregator aggregator = state.aggregators.computeIfAbsent(state.stationBuffer, state.stationPointerEnd, state.signedHashCode); - aggregator.count++; - aggregator.min = Math.min(aggregator.min, value); - aggregator.max = Math.max(aggregator.max, 
value); - aggregator.sum += value; - } - - if (stopAtNewline) { - return; - } - - // reset - state.reset(); - } - else { - if (!state.parsingValue) { - state.stationBuffer[state.stationPointerEnd++] = currentChar; - state.signedHashCode = 31 * state.signedHashCode + (currentChar & 0xff); - } - else { - if (currentChar == '-') { - state.exponent = -0.1; - } - else if (currentChar != '.') { - state.measurementValue = state.measurementValue * 10 + (currentChar - '0'); - } - } - } - } } - static final class State { + private static final class State { ByteArrayKeyedMap aggregators = new ByteArrayKeyedMap(); boolean parsingValue = false; byte[] stationBuffer = new byte[STATION_NAME_BUFFER_SIZE]; @@ -208,37 +164,51 @@ public class CalculateAverage_palmr { } public String toString() { - return round(min) + "/" + round(sum / count) + "/" + round(max); + return STR."\{round(min)}/\{round(sum / count)}/\{round(max)}"; } - private double round(double value) { + private double round(final double value) { return Math.round(value * 10.0) / 10.0; } + + private MeasurementAggregator merge(final MeasurementAggregator b) { + this.count += b.count; + this.min = Math.min(this.min, b.min); + this.max = Math.max(this.max, b.max); + this.sum += b.sum; + return this; + } } + /** + * Very basic hash table implementation, only implementing computeIfAbsent since that's all the code needs. + * It's sized to give minimal collisions with the example test set. This may not hold true if the stations list + * changes, but it should still perform fairly well. + * It uses Open Addressing, meaning it's just one array, rather than Separate Chaining which is what the default Java HashMap uses. + * It also uses linear probing for collision resolution, which given the minimal collision count should hold up well.
+ */ private static class ByteArrayKeyedMap { private final int BUCKET_COUNT = 0xFFF; // 413 unique stations in the data set, & 0xFFF ~= 399 (only 14 collisions (given our hashcode implementation)) private final MeasurementAggregator[] buckets = new MeasurementAggregator[BUCKET_COUNT + 1]; private final List compactUnorderedBuckets = new ArrayList<>(413); public MeasurementAggregator computeIfAbsent(final byte[] key, final int keyLength, final int keyHashCode) { - int index = keyHashCode & BUCKET_COUNT; + var index = keyHashCode & BUCKET_COUNT; while (true) { MeasurementAggregator maybe = buckets[index]; - if (maybe == null) { - final byte[] copiedKey = Arrays.copyOf(key, keyLength); - MeasurementAggregator measurementAggregator = new MeasurementAggregator(copiedKey, keyHashCode); - buckets[index] = measurementAggregator; - compactUnorderedBuckets.add(measurementAggregator); - return measurementAggregator; - } - else { + if (maybe != null) { if (Arrays.equals(key, 0, keyLength, maybe.stationNameBytes, 0, maybe.stationNameBytes.length)) { return maybe; } index++; index &= BUCKET_COUNT; + } else { + final var copiedKey = Arrays.copyOf(key, keyLength); + MeasurementAggregator measurementAggregator = new MeasurementAggregator(copiedKey, keyHashCode); + buckets[index] = measurementAggregator; + compactUnorderedBuckets.add(measurementAggregator); + return measurementAggregator; } } }