gabrielreid take 2

Clear up some TODOs and simplify the code a bit, which appears to
result in a 25% performance increase.
This commit is contained in:
greid 2024-01-07 14:27:23 +01:00 committed by Gunnar Morling
parent a8a3876416
commit 08af2622d3

View File

@ -22,10 +22,10 @@ import java.io.FileInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.LinkedBlockingDeque;
@ -45,12 +45,13 @@ public class CalculateAverage_gabrielreid {
private static final int BLOCK_READ_SIZE = 1024 * 1024 * 16;
private static final int SUMMARY_TABLE_SIZE = 2048;
private static final int MAP_INITIAL_SIZE = 450;
/**
* State with the full summary table, as well as leftover bytes between processed blocks that need to
* be handled afterward.
*/
record State(SummaryTable summaryTable, byte[] remainderBytes) {
record State(Map<String, CitySummary> map, byte[] remainderBytes) {
}
public static void main(String[] args) throws IOException {
@ -63,7 +64,7 @@ public class CalculateAverage_gabrielreid {
}
try (var fjp = new ForkJoinPool(numCores)) {
CompletableFuture<State> stateFuture = CompletableFuture.completedFuture(new State(new SummaryTable(SUMMARY_TABLE_SIZE), new byte[0]));
CompletableFuture<State> stateFuture = CompletableFuture.completedFuture(new State(new HashMap<>(MAP_INITIAL_SIZE), new byte[0]));
try (var fis = new FileInputStream("./measurements.txt")) {
var blockBuilder = Objects.requireNonNull(blockBuilderQueue.poll());
@ -77,7 +78,7 @@ public class CalculateAverage_gabrielreid {
skipToNewline = true;
stateFuture = stateFuture.thenCombine(
CompletableFuture.supplyAsync(() -> {
var summaryTable = localBlockBuilder.buildSummaryTable(localCnt, localSkipToNewline);
var summaryMap = localBlockBuilder.buildSummaryTable(localCnt, localSkipToNewline);
int unprocessedRemainderSize = localBlockBuilder.firstLineStart + (localCnt - localBlockBuilder.lastLineEnd);
var unprocessedBytes = new byte[unprocessedRemainderSize];
@ -87,20 +88,20 @@ public class CalculateAverage_gabrielreid {
localBlockBuilder.reset();
blockBuilderQueue.add(localBlockBuilder);
return new State(summaryTable, unprocessedBytes);
return new State(summaryMap, unprocessedBytes);
}, fjp), (state, newState) -> {
state.summaryTable.addAll(newState.summaryTable);
newState.map.forEach(
(k, v) -> state.map.merge(k, v, CitySummary::add));
var newRemainderBytes = new byte[state.remainderBytes.length + newState.remainderBytes.length];
System.arraycopy(state.remainderBytes, 0, newRemainderBytes, 0, state.remainderBytes.length);
System.arraycopy(newState.remainderBytes, 0, newRemainderBytes, state.remainderBytes.length, newState.remainderBytes.length);
return new State(state.summaryTable, newRemainderBytes);
return new State(state.map, newRemainderBytes);
});
try {
blockBuilder = blockBuilderQueue.poll(1, TimeUnit.HOURS);
}
catch (InterruptedException e) {
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
@ -109,33 +110,38 @@ public class CalculateAverage_gabrielreid {
}
stateFuture = stateFuture.thenApply(state -> {
BlockBuilder blockBuilder = null;
BlockBuilder blockBuilder;
try {
blockBuilder = blockBuilderQueue.poll(1, TimeUnit.HOURS);
}
catch (InterruptedException e) {
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
System.arraycopy(state.remainderBytes, 0, blockBuilder.readBuffer, 0, state.remainderBytes.length);
state.summaryTable.addAll(blockBuilder.buildSummaryTable(state.remainderBytes.length, false));
return new State(state.summaryTable, new byte[0]);
var m = blockBuilder.buildSummaryTable(state.remainderBytes.length, false);
m.forEach(
(k, v) -> state.map.merge(k, v, CitySummary::add));
return new State(state.map, new byte[0]);
});
var state = stateFuture.join();
System.out.println(state.summaryTable.toOutputString());
}
}
System.out.println(STR."{\{state.map.entrySet().stream().sorted(Map.Entry.comparingByKey())
.map(e -> String.format(Locale.US, "%s=%.1f/%.1f/%.1f", e.getKey(), e.getValue().min / 10f,
(e.getValue().sum / (float) e.getValue().count) / 10f, e.getValue().max / 10f))
.collect(Collectors.joining(", "))}}");
}}
/**
* Parses number values as integers from the byte array.
* <p>
* The multiplier is 1 if positive and -1 if negative.
*/
static short parseNumFromLine(byte[] buf, int offset, int len, int multiplier) {
static short parseNumFromLine(byte[] buf, int offset, int len) {
return switch (len) {
case 3 -> (short) ((((buf[offset] - '0') * 10) + buf[offset + 2] - '0') * multiplier);
case 4 -> (short) ((((buf[offset] - '0') * 100) + (buf[offset + 1] - '0') * 10 + buf[offset + 3] - '0') * multiplier);
case 3 -> (short) ((((buf[offset] - '0') * 10) + buf[offset + 2] - '0'));
case 4 -> (short) ((((buf[offset] - '0') * 100) + (buf[offset + 1] - '0') * 10 + buf[offset + 3] - '0'));
default -> throw new IllegalStateException("Unexpected number length %d".formatted(len));
};
}
@ -163,11 +169,12 @@ public class CalculateAverage_gabrielreid {
this.count++;
}
void add(CitySummary other) {
CitySummary add(CitySummary other) {
this.max = Math.max(other.max, this.max);
this.min = Math.min(other.min, this.min);
this.sum += other.sum;
this.count += other.count;
return this;
}
}
@ -180,43 +187,35 @@ public class CalculateAverage_gabrielreid {
static final class ByteSlice {
private final byte[] buf;
private int offset;
private int len;
private int hash;
private final int offset;
private final int len;
public ByteSlice(byte[] buf) {
public ByteSlice(byte[] buf, int offset, int len) {
this.buf = buf;
}
public static ByteSlice clone(ByteSlice src) {
var bytes = new byte[src.len];
System.arraycopy(src.buf, src.offset, bytes, 0, src.len);
var copy = new ByteSlice(bytes);
copy.offset = 0;
copy.len = src.len;
copy.hash = src.hash;
return copy;
}
@Override
public String toString() {
return "ByteSlice[%s]".formatted(new String(this.buf, this.offset, this.len, StandardCharsets.UTF_8));
this.offset = offset;
this.len = len;
}
/**
 * Decodes the {@code len} bytes starting at {@code offset} in the backing array as UTF-8.
 *
 * @return a newly allocated String holding the decoded text
 */
public String valueAsString() {
    final String text = new String(buf, offset, len, StandardCharsets.UTF_8);
    return text;
}
@Override
public int hashCode() {
return this.hash;
return hashCode(this.buf, this.offset, this.len);
}
public int calculateHashCode() {
public static int hashCode(byte[] buf, int offset, int len) {
int result = 1;
int end = offset + len;
for (int i = offset; i < end; i++) {
result = 31 * result + buf[i];
int i = 0;
for (; i + 3 < len; i += 4) {
result = 31 * 31 * 31 * 31 * result
+ 31 * 31 * 31 * buf[offset + i]
+ 31 * 31 * buf[offset + i + 1]
+ 31 * buf[offset + i + 2]
+ buf[offset + i + 3];
}
for (; i < len; i++) {
result = 31 * result + buf[offset + i];
}
return result;
}
@ -230,87 +229,119 @@ public class CalculateAverage_gabrielreid {
}
public static boolean equal(ByteSlice a, ByteSlice b) {
return a.hash == b.hash && Arrays.equals(a.buf, a.offset, a.offset + a.len, b.buf, b.offset, b.offset + b.len);
return Arrays.equals(a.buf, a.offset, a.offset + a.len, b.buf, b.offset, b.offset + b.len);
}
/**
 * Compares the bytes covered by a slice with a range of a raw byte array.
 *
 * @param a      slice whose bytes form the left-hand side of the comparison
 * @param buf    array holding the right-hand side bytes
 * @param offset index in {@code buf} of the first byte to compare
 * @param len    number of bytes in the right-hand side range
 * @return {@code true} if both ranges have the same length and the same contents
 */
public static boolean equal(ByteSlice a, byte[] buf, int offset, int len) {
    final int sliceEnd = a.offset + a.len;
    final int rangeEnd = offset + len;
    return Arrays.equals(a.buf, a.offset, sliceEnd, buf, offset, rangeEnd);
}
}
/**
 * Entry stored in a {@code SummaryTable} slot: the key bytes together with the
 * {@code CitySummary} accumulated for that key.
 *
 * @param byteSlice   bytes identifying the entry
 * @param citySummary running aggregate associated with {@code byteSlice}
 */
record ValueNode(ByteSlice byteSlice, CitySummary citySummary) {}
static final class SummaryTable {
private static final int LOAD_FACTOR = 4;
private int size;
private ByteSlice[] keys;
private CitySummary[] values;
private ValueNode[] values;
private int valueCount;
private int resizeThreshold;
private byte[] localBufferBytes = new byte[MAP_INITIAL_SIZE * 100];
private int localBufferPtr = 0;
SummaryTable(int size) {
this.size = size;
this.keys = new ByteSlice[size];
this.values = new CitySummary[size];
this.values = new ValueNode[size];
this.resizeThreshold = size / LOAD_FACTOR;
}
/**
 * Empties the table so it can be reused: clears every slot, zeroes the entry
 * count, and rewinds the key-byte buffer.
 */
void reset() {
    // values.length always equals size, so this clears every slot.
    Arrays.fill(this.values, null);
    // Without this the stale count keeps growing across reuses, and
    // resizeIfNecessary() doubles the table even though it is nearly empty.
    this.valueCount = 0;
    // Bytes left in localBufferBytes are overwritten by later inserts.
    this.localBufferPtr = 0;
}
public void addAll(SummaryTable other) {
for (int i = 0; i < other.size; i++) {
var otherSlice = other.keys[i];
var otherSlice = other.values[i];
if (otherSlice != null) {
putCitySummary(otherSlice, other.values[i]);
putValueNode(otherSlice);
}
}
}
private void putCitySummary(ByteSlice key, CitySummary value) {
resizeIfNecessary();
int index = (key.hash & 0x7FFFFFFF) % size;
while (keys[index] != null) {
if (ByteSlice.equal(keys[index], key)) {
values[index].add(value);
private void putValueNode(ValueNode valueNode) {
int hashCode = valueNode.byteSlice.hashCode();
int index = (hashCode & 0x7FFFFFFF) % size;
while (values[index] != null) {
if (ByteSlice.equal(values[index].byteSlice, valueNode.byteSlice)) {
values[index].citySummary.add(valueNode.citySummary);
return;
}
// TODO Consider secondary hash for the stride here
index = (index + 1) % size;
index = (index + (hashCode & 0xFF) + 1) % size;
}
keys[index] = key;
values[index] = value;
values[index] = valueNode;
valueCount++;
resizeIfNecessary();
}
public void putTemperatureValue(ByteSlice key, int value) {
resizeIfNecessary();
int index = (key.hash & 0x7FFFFFFF) % size;
while (keys[index] != null) {
if (ByteSlice.equal(keys[index], key)) {
values[index].add(value);
public void putTemperatureValue(byte[] buf, int offset, int len, int value) {
int hashCode = ByteSlice.hashCode(buf, offset, len);
int index = (hashCode & 0x7FFFFFFF) % size;
while (values[index] != null) {
if (ByteSlice.equal(values[index].byteSlice, buf, offset, len)) {
values[index].citySummary.add(value);
return;
}
// TODO Consider secondary hash for the stride here
// System.out.println("Collision!");
index = (index + 1) % size;
index = (index + (hashCode & 0xFF) + 1) % size;
}
keys[index] = ByteSlice.clone(key);
values[index] = new CitySummary(value);
System.arraycopy(buf, offset, this.localBufferBytes, this.localBufferPtr, len);
var byteSlice = new ByteSlice(this.localBufferBytes, this.localBufferPtr, len);
var valueNode = new ValueNode(byteSlice, new CitySummary(value));
localBufferPtr += len;
values[index] = valueNode;
valueCount++;
resizeIfNecessary();
}
private void resizeIfNecessary() {
if (valueCount == size) {
var resized = new SummaryTable(size * 2);
if (valueCount >= resizeThreshold) {
int newSize = size * 2;
var resized = new SummaryTable(newSize);
for (int i = 0; i < this.size; i++) {
if (this.values[i] != null) {
resized.putValueNode(this.values[i]);
}
}
resized.addAll(this);
this.keys = resized.keys;
byte[] localBufferBytes = new byte[this.localBufferBytes.length * 2];
System.arraycopy(this.localBufferBytes, 0, localBufferBytes, 0, this.localBufferPtr);
this.values = resized.values;
this.size = resized.size;
this.size = newSize;
this.valueCount = resized.valueCount;
this.resizeThreshold = newSize / LOAD_FACTOR;
this.localBufferBytes = localBufferBytes;
}
}
public String toOutputString() {
SortedMap<String, CitySummary> m = new TreeMap<>();
public Map<String, CitySummary> toMap() {
HashMap<String, CitySummary> m = HashMap.newHashMap(valueCount);
for (int i = 0; i < size; i++) {
ByteSlice slice = keys[i];
if (slice != null) {
m.put(slice.valueAsString(), values[i]);
var valueNode = this.values[i];
if (valueNode != null) {
m.put(valueNode.byteSlice.valueAsString(), valueNode.citySummary);
}
}
return "{" + m.entrySet().stream().map(e -> String.format(Locale.US, "%s=%.1f/%.1f/%.1f", e.getKey(), e.getValue().min / 10f,
(e.getValue().sum / (float) e.getValue().count) / 10f, e.getValue().max / 10f)).collect(Collectors.joining(", ")) + "}";
return m;
}
}
@ -319,79 +350,26 @@ public class CalculateAverage_gabrielreid {
*/
static class BlockBuilder {
final byte[] readBuffer;
private final int numSegmentLengths;
private final byte[] segmentLengths;
private final int[] hashCodes;
private final short[] temperatureValues;
private final ByteSlice byteSlice;
private final SummaryTable summaryTable;
private int firstLineStart;
private int lastLineEnd;
private int lineCount;
public BlockBuilder(int readBufferSize) {
this.readBuffer = new byte[readBufferSize];
// TODO This sizing is almost certainly non-optimal, but it seems to work
this.numSegmentLengths = readBufferSize / 2;
this.segmentLengths = new byte[numSegmentLengths];
this.hashCodes = new int[numSegmentLengths];
this.temperatureValues = new short[numSegmentLengths];
this.byteSlice = new ByteSlice(this.readBuffer);
this.summaryTable = new SummaryTable(SUMMARY_TABLE_SIZE);
}
void reset() {
firstLineStart = -1;
lastLineEnd = -1;
lineCount = 0;
this.summaryTable.reset();
}
public SummaryTable buildSummaryTable(int readByteCount, boolean skipToNewline) {
public Map<String, CitySummary> buildSummaryTable(int readByteCount, boolean skipToNewline) {
parseLineSegments(readByteCount, skipToNewline);
calculateHashesAndTemperatures();
SummaryTable summaryTable = new SummaryTable(SUMMARY_TABLE_SIZE);
this.byteSlice.offset = this.firstLineStart;
int segmentCount = lineCount * 2;
int lineCounter = 0;
for (int segmentIdx = 0; segmentIdx < segmentCount; segmentIdx += 2) {
// TODO It would likely be better if this ByteSlice was just a view on the arrays
this.byteSlice.len = this.segmentLengths[segmentIdx];
this.byteSlice.hash = this.hashCodes[lineCounter];
summaryTable.putTemperatureValue(this.byteSlice, this.temperatureValues[lineCounter]);
this.byteSlice.offset += (this.segmentLengths[segmentIdx] + (this.segmentLengths[segmentIdx + 1] & 0b01111111) + 2);
lineCounter++;
}
return summaryTable;
}
private void calculateHashesAndTemperatures() {
this.byteSlice.offset = this.firstLineStart;
int segmentCount = lineCount * 2;
int lineCounter = 0;
for (int segmentIdx = 0; segmentIdx < segmentCount; segmentIdx += 2) {
// TODO It would likely be better if this ByteSlice was just a view on the arrays
this.byteSlice.len = this.segmentLengths[segmentIdx];
this.hashCodes[lineCounter++] = this.byteSlice.calculateHashCode();
this.byteSlice.offset += (this.segmentLengths[segmentIdx] + (this.segmentLengths[segmentIdx + 1] & 0b01111111) + 2);
}
// TODO It might be better/faster to do this in the previous loop instead of second loop
lineCounter = 0;
int offset = this.firstLineStart + this.segmentLengths[0] + 1;
for (int segmentIdx = 1; segmentIdx < segmentCount; segmentIdx += 2) {
byte segmentLength = this.segmentLengths[segmentIdx];
var isNeg = (byte) (segmentLength >> 7);
int numLength = (segmentLength & 0b01111111) + isNeg;
short temperatureValue = parseNumFromLine(this.readBuffer, offset - isNeg, numLength, (isNeg * 2) + 1);
this.temperatureValues[lineCounter++] = temperatureValue;
offset += -isNeg + numLength + this.segmentLengths[segmentIdx + 1] + 2;
}
return summaryTable.toMap();
}
private void parseLineSegments(int byteCount, boolean skipToNewline) {
@ -406,22 +384,23 @@ public class CalculateAverage_gabrielreid {
}
this.firstLineStart = idx;
this.lineCount = 0;
int segmentCounter = 0;
int lineStart = idx;
while (idx < upperBound && lineCount < numSegmentLengths) {
while (idx < upperBound) {
var byteVector = ByteVector.fromArray(BYTE_SPECIES, readBuffer, idx);
var newlineByteMask = byteVector.eq(NEWLINE_BYTE);
if (newlineByteMask.anyTrue()) {
var newlineIdx = byteVector.eq(NEWLINE_BYTE).firstTrue();
if (newlineIdx < BYTE_SPECIES_LEN) {
var semicolonByteMask = byteVector.eq(SEMICOLON_BYTE);
var semicolonIdx = semicolonByteMask.firstTrue();
int semicolonOffset = idx + semicolonIdx;
int lineEnd = idx + newlineByteMask.firstTrue();
byte negMult = (byte) ((readBuffer[idx + semicolonIdx + 1] == NEG_BYTE) ? 0b10000000 : 0);
this.segmentLengths[segmentCounter++] = (byte) (semicolonOffset - lineStart);
this.segmentLengths[segmentCounter++] = (byte) ((lineEnd - (semicolonOffset + 1)) | negMult);
this.lineCount++;
int lineEnd = idx + newlineIdx;
short negative = (short) ((readBuffer[idx + semicolonIdx + 1] == NEG_BYTE) ? 1 : 0);
int numLength = (byte) ((lineEnd - (semicolonOffset + 1)));
var num = parseNumFromLine(this.readBuffer, semicolonOffset + 1 + negative, numLength - negative);
num = negative == 1 ? (short) -num : num;
int nameLen = semicolonOffset - lineStart;
summaryTable.putTemperatureValue(this.readBuffer, lineStart, nameLen, num);
idx = lineEnd + 1;
lastLineEnd = idx;
lineStart = idx;
@ -439,19 +418,19 @@ public class CalculateAverage_gabrielreid {
}
if (readBuffer[idx] == NEWLINE_BYTE) {
int lineEnd = idx;
byte negMult = (byte) ((readBuffer[semicolonIdx + 1] == NEG_BYTE) ? 0b10000000 : 0);
this.segmentLengths[segmentCounter++] = (byte) (semicolonIdx - lineStart);
this.segmentLengths[segmentCounter++] = (byte) ((lineEnd - (semicolonIdx + 1)) | negMult);
this.lineCount++;
short negative = (short) ((readBuffer[semicolonIdx + 1] == NEG_BYTE) ? 1 : 0);
int numLength = (byte) ((lineEnd - (semicolonIdx + 1)));
var num = parseNumFromLine(this.readBuffer, semicolonIdx + 1 + negative, numLength - negative);
num = negative == 1 ? (short) -num : num;
int nameLen = semicolonIdx - lineStart;
summaryTable.putTemperatureValue(this.readBuffer, lineStart, nameLen, num);
idx = lineEnd + 1;
lastLineEnd = idx;
lineStart = idx;
}
idx++;
}
}
}
}