From 08af2622d39bae9426c56d9d25913f21627745cc Mon Sep 17 00:00:00 2001 From: greid Date: Sun, 7 Jan 2024 14:27:23 +0100 Subject: [PATCH] gabrielreid take 2 Clear up some TODOS, simplify the code a bit, which appears to result in a 25% performance increase. --- .../onebrc/CalculateAverage_gabrielreid.java | 303 ++++++++---------- 1 file changed, 141 insertions(+), 162 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_gabrielreid.java b/src/main/java/dev/morling/onebrc/CalculateAverage_gabrielreid.java index a682841..7c51e4c 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_gabrielreid.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_gabrielreid.java @@ -22,10 +22,10 @@ import java.io.FileInputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.HashMap; import java.util.Locale; +import java.util.Map; import java.util.Objects; -import java.util.SortedMap; -import java.util.TreeMap; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.LinkedBlockingDeque; @@ -45,12 +45,13 @@ public class CalculateAverage_gabrielreid { private static final int BLOCK_READ_SIZE = 1024 * 1024 * 16; private static final int SUMMARY_TABLE_SIZE = 2048; + private static final int MAP_INITIAL_SIZE = 450; /** * State with the full summary table, as well as leftover bytes between processed blocks that need to * be handled afterward. */ - record State(SummaryTable summaryTable, byte[] remainderBytes) { + record State(Map map, byte[] remainderBytes) { } public static void main(String[] args) throws IOException { @@ -63,7 +64,7 @@ public class CalculateAverage_gabrielreid { } try (var fjp = new ForkJoinPool(numCores)) { - CompletableFuture stateFuture = CompletableFuture.completedFuture(new State(new SummaryTable(SUMMARY_TABLE_SIZE), new byte[0])); + CompletableFuture stateFuture = CompletableFuture.completedFuture(new State(new HashMap<>(MAP_INITIAL_SIZE), new byte[0])); try (var fis = new FileInputStream("./measurements.txt")) { var blockBuilder = Objects.requireNonNull(blockBuilderQueue.poll()); @@ -77,7 +78,7 @@ public class CalculateAverage_gabrielreid { skipToNewline = true; stateFuture = stateFuture.thenCombine( CompletableFuture.supplyAsync(() -> { - var summaryTable = localBlockBuilder.buildSummaryTable(localCnt, localSkipToNewline); + var summaryMap = localBlockBuilder.buildSummaryTable(localCnt, localSkipToNewline); int unprocessedRemainderSize = localBlockBuilder.firstLineStart + (localCnt - localBlockBuilder.lastLineEnd); var unprocessedBytes = new byte[unprocessedRemainderSize]; @@ -87,20 +88,20 @@ public class CalculateAverage_gabrielreid { localBlockBuilder.reset(); blockBuilderQueue.add(localBlockBuilder); - return new State(summaryTable, unprocessedBytes); + return new State(summaryMap, unprocessedBytes); }, fjp), (state, newState) -> { - state.summaryTable.addAll(newState.summaryTable); + newState.map.forEach( + (k, v) -> state.map.merge(k, v, CitySummary::add)); var newRemainderBytes = new byte[state.remainderBytes.length + newState.remainderBytes.length]; System.arraycopy(state.remainderBytes, 0, newRemainderBytes, 0, state.remainderBytes.length); System.arraycopy(newState.remainderBytes, 0, newRemainderBytes, state.remainderBytes.length, newState.remainderBytes.length); - return new State(state.summaryTable, newRemainderBytes); + return new State(state.map, newRemainderBytes); }); try { blockBuilder = blockBuilderQueue.poll(1, TimeUnit.HOURS); - } - catch (InterruptedException e) { + } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException(e); } @@ -109,33 +110,38 @@ public class CalculateAverage_gabrielreid { } stateFuture = stateFuture.thenApply(state -> { - BlockBuilder blockBuilder = null; + BlockBuilder blockBuilder; try { blockBuilder = blockBuilderQueue.poll(1, TimeUnit.HOURS); - } - catch (InterruptedException e) { + } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException(e); } System.arraycopy(state.remainderBytes, 0, blockBuilder.readBuffer, 0, state.remainderBytes.length); - state.summaryTable.addAll(blockBuilder.buildSummaryTable(state.remainderBytes.length, false)); - return new State(state.summaryTable, new byte[0]); + + var m = blockBuilder.buildSummaryTable(state.remainderBytes.length, false); + m.forEach( + (k, v) -> state.map.merge(k, v, CitySummary::add)); + return new State(state.map, new byte[0]); }); var state = stateFuture.join(); - System.out.println(state.summaryTable.toOutputString()); - } - } + System.out.println(STR."{\{state.map.entrySet().stream().sorted(Map.Entry.comparingByKey()) + .map(e -> String.format(Locale.US, "%s=%.1f/%.1f/%.1f", e.getKey(), e.getValue().min / 10f, + (e.getValue().sum / (float) e.getValue().count) / 10f, e.getValue().max / 10f)) + .collect(Collectors.joining(", "))}}"); + + }} /** * Parses number values as integers from the byte array. *

* The multiplier is 1 if positive and -1 if negative. */ - static short parseNumFromLine(byte[] buf, int offset, int len, int multiplier) { + static short parseNumFromLine(byte[] buf, int offset, int len) { return switch (len) { - case 3 -> (short) ((((buf[offset] - '0') * 10) + buf[offset + 2] - '0') * multiplier); - case 4 -> (short) ((((buf[offset] - '0') * 100) + (buf[offset + 1] - '0') * 10 + buf[offset + 3] - '0') * multiplier); + case 3 -> (short) ((((buf[offset] - '0') * 10) + buf[offset + 2] - '0')); + case 4 -> (short) ((((buf[offset] - '0') * 100) + (buf[offset + 1] - '0') * 10 + buf[offset + 3] - '0')); default -> throw new IllegalStateException("Unexpected number length %d".formatted(len)); }; } @@ -163,11 +169,12 @@ public class CalculateAverage_gabrielreid { this.count++; } - void add(CitySummary other) { + CitySummary add(CitySummary other) { this.max = Math.max(other.max, this.max); this.min = Math.min(other.min, this.min); this.sum += other.sum; this.count += other.count; + return this; } } @@ -180,43 +187,35 @@ public class CalculateAverage_gabrielreid { static final class ByteSlice { private final byte[] buf; - private int offset; - private int len; - private int hash; + private final int offset; + private final int len; - public ByteSlice(byte[] buf) { + public ByteSlice(byte[] buf, int offset, int len) { this.buf = buf; - } - - public static ByteSlice clone(ByteSlice src) { - var bytes = new byte[src.len]; - System.arraycopy(src.buf, src.offset, bytes, 0, src.len); - var copy = new ByteSlice(bytes); - copy.offset = 0; - copy.len = src.len; - copy.hash = src.hash; - return copy; - } - - @Override - public String toString() { - return "ByteSlice[%s]".formatted(new String(this.buf, this.offset, this.len, StandardCharsets.UTF_8)); + this.offset = offset; + this.len = len; } public String valueAsString() { return new String(this.buf, this.offset, this.len, StandardCharsets.UTF_8); } - @Override public int hashCode() { - return this.hash; + return hashCode(this.buf, this.offset, this.len); } - public int calculateHashCode() { + public static int hashCode(byte[] buf, int offset, int len) { int result = 1; - int end = offset + len; - for (int i = offset; i < end; i++) { - result = 31 * result + buf[i]; + int i = 0; + for (; i + 3 < len; i += 4) { + result = 31 * 31 * 31 * 31 * result + + 31 * 31 * 31 * buf[offset + i] + + 31 * 31 * buf[offset + i + 1] + + 31 * buf[offset + i + 2] + + buf[offset + i + 3]; + } + for (; i < len; i++) { + result = 31 * result + buf[offset + i]; } return result; } @@ -230,87 +229,119 @@ public class CalculateAverage_gabrielreid { } public static boolean equal(ByteSlice a, ByteSlice b) { - return a.hash == b.hash && Arrays.equals(a.buf, a.offset, a.offset + a.len, b.buf, b.offset, b.offset + b.len); + return Arrays.equals(a.buf, a.offset, a.offset + a.len, b.buf, b.offset, b.offset + b.len); } + public static boolean equal(ByteSlice a, byte[] buf, int offset, int len) { + return Arrays.equals(a.buf, a.offset, a.offset + a.len, buf, offset, offset + len); + } + + } + + record ValueNode(ByteSlice byteSlice, CitySummary citySummary) { + } static final class SummaryTable { + private static final int LOAD_FACTOR = 4; + private int size; - private ByteSlice[] keys; - private CitySummary[] values; + private ValueNode[] values; private int valueCount; + private int resizeThreshold; + + private byte[] localBufferBytes = new byte[MAP_INITIAL_SIZE * 100]; + private int localBufferPtr = 0; SummaryTable(int size) { this.size = size; - this.keys = new ByteSlice[size]; - this.values = new CitySummary[size]; + this.values = new ValueNode[size]; + this.resizeThreshold = size / LOAD_FACTOR; + } + + void reset() { + for (int i = 0; i < size; i++) { + this.values[i] = null; + } + localBufferPtr = 0; } public void addAll(SummaryTable other) { for (int i = 0; i < other.size; i++) { - var otherSlice = other.keys[i]; + var otherSlice = other.values[i]; if (otherSlice != null) { - putCitySummary(otherSlice, other.values[i]); + putValueNode(otherSlice); } } } - private void putCitySummary(ByteSlice key, CitySummary value) { - resizeIfNecessary(); - int index = (key.hash & 0x7FFFFFFF) % size; - while (keys[index] != null) { - if (ByteSlice.equal(keys[index], key)) { - values[index].add(value); + private void putValueNode(ValueNode valueNode) { + int hashCode = valueNode.byteSlice.hashCode(); + int index = (hashCode & 0x7FFFFFFF) % size; + while (values[index] != null) { + if (ByteSlice.equal(values[index].byteSlice, valueNode.byteSlice)) { + values[index].citySummary.add(valueNode.citySummary); return; } - // TODO Consider secondary hash for the stride here - index = (index + 1) % size; + index = (index + (hashCode & 0xFF) + 1) % size; } - keys[index] = key; - values[index] = value; + values[index] = valueNode; valueCount++; + resizeIfNecessary(); } - public void putTemperatureValue(ByteSlice key, int value) { - resizeIfNecessary(); - int index = (key.hash & 0x7FFFFFFF) % size; - while (keys[index] != null) { - if (ByteSlice.equal(keys[index], key)) { - values[index].add(value); + public void putTemperatureValue(byte[] buf, int offset, int len, int value) { + + int hashCode = ByteSlice.hashCode(buf, offset, len); + int index = (hashCode & 0x7FFFFFFF) % size; + while (values[index] != null) { + if (ByteSlice.equal(values[index].byteSlice, buf, offset, len)) { + values[index].citySummary.add(value); return; } - // TODO Consider secondary hash for the stride here - // System.out.println("Collision!"); - index = (index + 1) % size; + index = (index + (hashCode & 0xFF) + 1) % size; } - keys[index] = ByteSlice.clone(key); - values[index] = new CitySummary(value); + + System.arraycopy(buf, offset, this.localBufferBytes, this.localBufferPtr, len); + var byteSlice = new ByteSlice(this.localBufferBytes, this.localBufferPtr, len); + var valueNode = new ValueNode(byteSlice, new CitySummary(value)); + localBufferPtr += len; + + values[index] = valueNode; valueCount++; + resizeIfNecessary(); } private void resizeIfNecessary() { - if (valueCount == size) { - var resized = new SummaryTable(size * 2); + if (valueCount >= resizeThreshold) { + int newSize = size * 2; + var resized = new SummaryTable(newSize); + for (int i = 0; i < this.size; i++) { + if (this.values[i] != null) { + resized.putValueNode(this.values[i]); + } + } resized.addAll(this); - this.keys = resized.keys; + byte[] localBufferBytes = new byte[this.localBufferBytes.length * 2]; + System.arraycopy(this.localBufferBytes, 0, localBufferBytes, 0, this.localBufferPtr); this.values = resized.values; - this.size = resized.size; + this.size = newSize; this.valueCount = resized.valueCount; + this.resizeThreshold = newSize / LOAD_FACTOR; + this.localBufferBytes = localBufferBytes; } } - public String toOutputString() { - SortedMap m = new TreeMap<>(); + public Map toMap() { + HashMap m = HashMap.newHashMap(valueCount); for (int i = 0; i < size; i++) { - ByteSlice slice = keys[i]; - if (slice != null) { - m.put(slice.valueAsString(), values[i]); + var valueNode = this.values[i]; + if (valueNode != null) { + m.put(valueNode.byteSlice.valueAsString(), valueNode.citySummary); } } - return "{" + m.entrySet().stream().map(e -> String.format(Locale.US, "%s=%.1f/%.1f/%.1f", e.getKey(), e.getValue().min / 10f, - (e.getValue().sum / (float) e.getValue().count) / 10f, e.getValue().max / 10f)).collect(Collectors.joining(", ")) + "}"; + return m; } } @@ -319,79 +350,26 @@ public class CalculateAverage_gabrielreid { */ static class BlockBuilder { final byte[] readBuffer; - private final int numSegmentLengths; - private final byte[] segmentLengths; - private final int[] hashCodes; - private final short[] temperatureValues; - private final ByteSlice byteSlice; + private final SummaryTable summaryTable; private int firstLineStart; private int lastLineEnd; - private int lineCount; public BlockBuilder(int readBufferSize) { this.readBuffer = new byte[readBufferSize]; - // TODO This sizing is almost certainly non-optimal, but it seems to work - this.numSegmentLengths = readBufferSize / 2; - this.segmentLengths = new byte[numSegmentLengths]; - this.hashCodes = new int[numSegmentLengths]; - this.temperatureValues = new short[numSegmentLengths]; - this.byteSlice = new ByteSlice(this.readBuffer); + this.summaryTable = new SummaryTable(SUMMARY_TABLE_SIZE); } void reset() { firstLineStart = -1; lastLineEnd = -1; - lineCount = 0; + this.summaryTable.reset(); + } - public SummaryTable buildSummaryTable(int readByteCount, boolean skipToNewline) { - + public Map buildSummaryTable(int readByteCount, boolean skipToNewline) { parseLineSegments(readByteCount, skipToNewline); - calculateHashesAndTemperatures(); - - SummaryTable summaryTable = new SummaryTable(SUMMARY_TABLE_SIZE); - - this.byteSlice.offset = this.firstLineStart; - - int segmentCount = lineCount * 2; - int lineCounter = 0; - for (int segmentIdx = 0; segmentIdx < segmentCount; segmentIdx += 2) { - // TODO It would likely be better if this ByteSlice was just a view on the arrays - this.byteSlice.len = this.segmentLengths[segmentIdx]; - this.byteSlice.hash = this.hashCodes[lineCounter]; - summaryTable.putTemperatureValue(this.byteSlice, this.temperatureValues[lineCounter]); - this.byteSlice.offset += (this.segmentLengths[segmentIdx] + (this.segmentLengths[segmentIdx + 1] & 0b01111111) + 2); - lineCounter++; - } - - return summaryTable; - - } - - private void calculateHashesAndTemperatures() { - this.byteSlice.offset = this.firstLineStart; - - int segmentCount = lineCount * 2; - int lineCounter = 0; - for (int segmentIdx = 0; segmentIdx < segmentCount; segmentIdx += 2) { - // TODO It would likely be better if this ByteSlice was just a view on the arrays - this.byteSlice.len = this.segmentLengths[segmentIdx]; - this.hashCodes[lineCounter++] = this.byteSlice.calculateHashCode(); - this.byteSlice.offset += (this.segmentLengths[segmentIdx] + (this.segmentLengths[segmentIdx + 1] & 0b01111111) + 2); - } - - // TODO It might be better/faster to do this in the previous loop instead of second loop - lineCounter = 0; - int offset = this.firstLineStart + this.segmentLengths[0] + 1; - for (int segmentIdx = 1; segmentIdx < segmentCount; segmentIdx += 2) { - byte segmentLength = this.segmentLengths[segmentIdx]; - var isNeg = (byte) (segmentLength >> 7); - int numLength = (segmentLength & 0b01111111) + isNeg; - short temperatureValue = parseNumFromLine(this.readBuffer, offset - isNeg, numLength, (isNeg * 2) + 1); - this.temperatureValues[lineCounter++] = temperatureValue; - offset += -isNeg + numLength + this.segmentLengths[segmentIdx + 1] + 2; - } + return summaryTable.toMap(); } private void parseLineSegments(int byteCount, boolean skipToNewline) { @@ -406,22 +384,23 @@ public class CalculateAverage_gabrielreid { } this.firstLineStart = idx; - this.lineCount = 0; - int segmentCounter = 0; int lineStart = idx; - while (idx < upperBound && lineCount < numSegmentLengths) { + while (idx < upperBound) { var byteVector = ByteVector.fromArray(BYTE_SPECIES, readBuffer, idx); - var newlineByteMask = byteVector.eq(NEWLINE_BYTE); - if (newlineByteMask.anyTrue()) { + var newlineIdx = byteVector.eq(NEWLINE_BYTE).firstTrue(); + if (newlineIdx < BYTE_SPECIES_LEN) { var semicolonByteMask = byteVector.eq(SEMICOLON_BYTE); var semicolonIdx = semicolonByteMask.firstTrue(); + int semicolonOffset = idx + semicolonIdx; - int lineEnd = idx + newlineByteMask.firstTrue(); - byte negMult = (byte) ((readBuffer[idx + semicolonIdx + 1] == NEG_BYTE) ? 0b10000000 : 0); - this.segmentLengths[segmentCounter++] = (byte) (semicolonOffset - lineStart); - this.segmentLengths[segmentCounter++] = (byte) ((lineEnd - (semicolonOffset + 1)) | negMult); - this.lineCount++; + int lineEnd = idx + newlineIdx; + short negative = (short) ((readBuffer[idx + semicolonIdx + 1] == NEG_BYTE) ? 1 : 0); + int numLength = (byte) ((lineEnd - (semicolonOffset + 1))); + var num = parseNumFromLine(this.readBuffer, semicolonOffset + 1 + negative, numLength - negative); + num = negative == 1 ? (short) -num : num; + int nameLen = semicolonOffset - lineStart; + summaryTable.putTemperatureValue(this.readBuffer, lineStart, nameLen, num); idx = lineEnd + 1; lastLineEnd = idx; lineStart = idx; @@ -439,19 +418,19 @@ public class CalculateAverage_gabrielreid { } if (readBuffer[idx] == NEWLINE_BYTE) { int lineEnd = idx; - byte negMult = (byte) ((readBuffer[semicolonIdx + 1] == NEG_BYTE) ? 0b10000000 : 0); - this.segmentLengths[segmentCounter++] = (byte) (semicolonIdx - lineStart); - this.segmentLengths[segmentCounter++] = (byte) ((lineEnd - (semicolonIdx + 1)) | negMult); - this.lineCount++; + + short negative = (short) ((readBuffer[semicolonIdx + 1] == NEG_BYTE) ? 1 : 0); + int numLength = (byte) ((lineEnd - (semicolonIdx + 1))); + var num = parseNumFromLine(this.readBuffer, semicolonIdx + 1 + negative, numLength - negative); + num = negative == 1 ? (short) -num : num; + int nameLen = semicolonIdx - lineStart; + summaryTable.putTemperatureValue(this.readBuffer, lineStart, nameLen, num); idx = lineEnd + 1; lastLineEnd = idx; lineStart = idx; } idx++; } - } - } - }