Attempt nicer threading via streams and spliterators

Nick Palmer 2024-01-03 22:18:40 +00:00 committed by Gunnar Morling
parent b2cd84c6bc
commit 6aa63e1bd5


@@ -21,129 +21,89 @@ import java.nio.ByteBuffer;
 import java.nio.channels.FileChannel;
 import java.nio.charset.StandardCharsets;
 import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;

 public class CalculateAverage_palmr {

     private static final String FILE = "./measurements.txt";
-    public static final int CHUNK_SIZE = 1024 * 1024 * 10; // Trial and error showed ~10MB to be a good size on our machine
-    public static final int LITTLE_CHUNK_SIZE = 128; // Enough bytes to cover a station name and measurement value :fingers-crossed:
-    public static final int STATION_NAME_BUFFER_SIZE = 50;
-    public static final int THREAD_COUNT = Math.min(8, Runtime.getRuntime().availableProcessors());
+    private static final int CHUNK_SIZE = 1024 * 1024 * 10; // Trial and error showed ~10MB to be a good size on our machine
+    private static final int STATION_NAME_BUFFER_SIZE = 50;
+    private static final int THREAD_COUNT = Math.min(8, Runtime.getRuntime().availableProcessors());
+    private static final char SEPARATOR_CHAR = ';';
+    private static final char NEWLINE_CHAR = '\n';
+    private static final char MINUS_CHAR = '-';
+    private static final char DECIMAL_POINT_CHAR = '.';

     public static void main(String[] args) throws IOException {
         @SuppressWarnings("resource") // It's faster to leak the file than be well-behaved
-        RandomAccessFile file = new RandomAccessFile(FILE, "r");
-        FileChannel channel = file.getChannel();
-        long fileSize = channel.size();
-        long threadChunk = fileSize / THREAD_COUNT;
-
-        Thread[] threads = new Thread[THREAD_COUNT];
-        ByteArrayKeyedMap[] results = new ByteArrayKeyedMap[THREAD_COUNT];
-        for (int i = 0; i < THREAD_COUNT; i++) {
-            final int j = i;
-            long startPoint = j * threadChunk;
-            long endPoint = startPoint + threadChunk;
-            Thread thread = new Thread(() -> {
-                try {
-                    results[j] = readAndParse(channel, startPoint, endPoint, fileSize);
-                }
-                catch (Throwable t) {
-                    System.err.println("It's broken :(");
-                    // noinspection CallToPrintStackTrace
-                    t.printStackTrace();
-                }
-            });
-            threads[i] = thread;
-            thread.start();
-        }
+        final var file = new RandomAccessFile(FILE, "r");
+        final var channel = file.getChannel();
+
+        final TreeMap<String, MeasurementAggregator> results = StreamSupport.stream(ThreadChunk.chunk(file, THREAD_COUNT), true)
+                .map(chunk -> parseChunk(chunk, channel))
+                .flatMap(bakm -> bakm.getAsUnorderedList().stream())
+                .collect(Collectors.toMap(m -> new String(m.stationNameBytes, StandardCharsets.UTF_8), m -> m, MeasurementAggregator::merge, TreeMap::new));
+        System.out.println(results);
+    }
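
The whole fan-out/fan-in now lives in that one stream pipeline: each chunk is parsed in parallel, and the three-argument Collectors.toMap merges per-chunk aggregates for the same station via MeasurementAggregator::merge, with TreeMap::new giving the sorted output. A minimal self-contained sketch of the same merge-collector pattern (Agg, its station names, and counts are invented for illustration, not part of the commit):

    import java.util.TreeMap;
    import java.util.stream.Collectors;
    import java.util.stream.Stream;

    public class MergeCollectorSketch {
        // Stand-in for MeasurementAggregator: merge() combines two partial results for one key
        record Agg(String station, long count) {
            Agg merge(Agg other) {
                return new Agg(station, count + other.count);
            }
        }

        public static void main(String[] args) {
            // Pretend two parallel chunks both saw "Oslo"; toMap resolves the key collision with merge()
            var merged = Stream.of(new Agg("Oslo", 2), new Agg("Paris", 1), new Agg("Oslo", 3))
                    .parallel()
                    .collect(Collectors.toMap(Agg::station, a -> a, Agg::merge, TreeMap::new));
            System.out.println(merged); // {Oslo=..., Paris=...} in key order
        }
    }
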
-        final Map<String, MeasurementAggregator> finalAggregator = new TreeMap<>();
-        for (int i = 0; i < THREAD_COUNT; i++) {
-            try {
-                threads[i].join();
-            }
-            catch (InterruptedException e) {
-                throw new RuntimeException(e);
-            }
-            results[i].getAsUnorderedList().forEach(v -> {
-                String stationName = new String(v.stationNameBytes, StandardCharsets.UTF_8);
-                finalAggregator.compute(stationName, (_, x) -> {
-                    if (x == null) {
-                        return v;
-                    }
-                    else {
-                        x.count += v.count;
-                        x.min = Math.min(x.min, v.min);
-                        x.max = Math.max(x.max, v.max);
-                        x.sum += v.sum;
-                        return x;
-                    }
-                });
-            });
-        }
-        System.out.println(finalAggregator);
-    }
+    private record ThreadChunk(long startPoint, long endPoint, long size) {
+        public static Spliterator<CalculateAverage_palmr.ThreadChunk> chunk(final RandomAccessFile file, final int chunkCount) throws IOException {
+            final var fileSize = file.length();
+            final var idealChunkSize = fileSize / THREAD_COUNT;
+            final var chunks = new CalculateAverage_palmr.ThreadChunk[chunkCount];
+            var startPoint = 0L;
+            for (int i = 0; i < chunkCount; i++) {
+                var endPoint = Math.min(startPoint + idealChunkSize, fileSize);
+                file.seek(endPoint);
+                while (endPoint < fileSize && file.readByte() != NEWLINE_CHAR) {
+                    endPoint++;
+                }
+                final var actualSize = endPoint - startPoint;
+                chunks[i] = new CalculateAverage_palmr.ThreadChunk(startPoint, endPoint, actualSize);
+                startPoint += actualSize;
+            }
+            return Spliterators.spliterator(chunks,
+                    Spliterator.ORDERED |
+                            Spliterator.DISTINCT |
+                            Spliterator.SORTED |
+                            Spliterator.NONNULL |
+                            Spliterator.IMMUTABLE |
+                            Spliterator.CONCURRENT);
+        }
+    }
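
The key property chunk() establishes is that every chunk boundary lands on a newline, so no record straddles two chunks and the old skip-to-first-clean-entry logic in the parser can go away. The boundary-alignment idea in isolation (alignToNewline is a hypothetical helper mirroring the inner loop above, not part of the commit):

    import java.io.IOException;
    import java.io.RandomAccessFile;

    public class ChunkBoundarySketch {
        // Push a tentative split point forward until it sits on a newline,
        // so a record like "Oslo;-1.2\n" is never cut in half between chunks.
        static long alignToNewline(RandomAccessFile file, long tentativeEnd) throws IOException {
            final long fileSize = file.length();
            long end = tentativeEnd;
            file.seek(end);
            while (end < fileSize && file.readByte() != '\n') {
                end++;
            }
            return end;
        }
    }
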
-    private static ByteArrayKeyedMap readAndParse(final FileChannel channel,
-                                                  final long startPoint,
-                                                  final long endPoint,
-                                                  final long fileSize) {
-        final State state = new State();
-        boolean skipFirstEntry = startPoint != 0;
-        long offset = startPoint;
-        while (offset < endPoint) {
-            parseData(channel, state, offset, Math.min(CHUNK_SIZE, fileSize - offset), false, skipFirstEntry);
-            skipFirstEntry = false;
+    private static ByteArrayKeyedMap parseChunk(ThreadChunk chunk, FileChannel channel) {
+        final var state = new State();
+        var offset = chunk.startPoint;
+        while (offset < chunk.endPoint) {
+            parseData(channel, state, offset, Math.min(CHUNK_SIZE, chunk.endPoint - offset));
             offset += CHUNK_SIZE;
         }
-        if (offset < fileSize) {
-            // Make sure we finish reading any partially read entry by going a little in to the next chunk, stopping at the first newline
-            parseData(channel, state, offset, Math.min(LITTLE_CHUNK_SIZE, fileSize - offset), true, false);
-        }
         return state.aggregators;
     }
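
parseChunk still maps its range in ~10MB windows (CHUNK_SIZE) rather than one mapping per thread chunk, and the shared State carries any half-read record across window boundaries. The mmap windowing pattern on its own, as a sketch (class and method names assumed for illustration):

    import java.nio.channels.FileChannel;

    public class WindowedMapSketch {
        // Walk [start, end) of a channel in fixed-size read-only mappings,
        // the same windowing parseChunk drives parseData with.
        static void forEachWindow(FileChannel channel, long start, long end, int windowSize) throws Exception {
            for (long offset = start; offset < end; offset += windowSize) {
                var buffer = channel.map(FileChannel.MapMode.READ_ONLY, offset, Math.min(windowSize, end - offset));
                while (buffer.hasRemaining()) {
                    buffer.get(); // a real caller would parse bytes here
                }
            }
        }
    }
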
     private static void parseData(final FileChannel channel,
                                   final State state,
                                   final long offset,
-                                  final long bufferSize,
-                                  final boolean stopAtNewline,
-                                  final boolean skipFirstEntry) {
-        ByteBuffer byteBuffer;
+                                  final long bufferSize) {
+        final ByteBuffer byteBuffer;
         try {
             byteBuffer = channel.map(FileChannel.MapMode.READ_ONLY, offset, bufferSize);
-        }
-        catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-
-        boolean isSkippingToFirstCleanEntry = skipFirstEntry;
             while (byteBuffer.hasRemaining()) {
-                byte currentChar = byteBuffer.get();
-                if (isSkippingToFirstCleanEntry) {
-                    if (currentChar == '\n') {
-                        isSkippingToFirstCleanEntry = false;
-                    }
-                    continue;
-                }
-                if (currentChar == ';') {
+                final var currentChar = byteBuffer.get();
+                if (currentChar == SEPARATOR_CHAR) {
                     state.parsingValue = true;
-                }
-                else if (currentChar == '\n') {
+                } else if (currentChar == NEWLINE_CHAR) {
                     if (state.stationPointerEnd != 0) {
-                        double value = state.measurementValue * state.exponent;
+                        final var value = state.measurementValue * state.exponent;
                         MeasurementAggregator aggregator = state.aggregators.computeIfAbsent(state.stationBuffer, state.stationPointerEnd, state.signedHashCode);
                         aggregator.count++;
@@ -152,31 +112,27 @@ public class CalculateAverage_palmr {
                         aggregator.sum += value;
                     }
-                    if (stopAtNewline) {
-                        return;
-                    }
                     // reset
                     state.reset();
-                }
-                else {
+                } else {
                     if (!state.parsingValue) {
                         state.stationBuffer[state.stationPointerEnd++] = currentChar;
                         state.signedHashCode = 31 * state.signedHashCode + (currentChar & 0xff);
-                    }
-                    else {
-                        if (currentChar == '-') {
+                    } else {
+                        if (currentChar == MINUS_CHAR) {
                             state.exponent = -0.1;
-                        }
-                        else if (currentChar != '.') {
+                        } else if (currentChar != DECIMAL_POINT_CHAR) {
                             state.measurementValue = state.measurementValue * 10 + (currentChar - '0');
                         }
                     }
                 }
             }
-        }
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
     }
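
The value parser never materializes the number as text: digits accumulate into an int, and both the sign and the single decimal place are folded into exponent, so "-12.3" is read as 123 * -0.1. The trick in isolation (a standalone sketch, not the commit's code):

    public class FixedPointParseSketch {
        // Parse a one-decimal measurement the way parseData does: skip '-' and '.',
        // accumulate digits, and fold sign plus scale into one multiplier.
        static double parse(String s) {
            double exponent = 0.1; // flipped to -0.1 when '-' is seen
            int measurementValue = 0;
            for (int i = 0; i < s.length(); i++) {
                final char c = s.charAt(i);
                if (c == '-') {
                    exponent = -0.1;
                } else if (c != '.') {
                    measurementValue = measurementValue * 10 + (c - '0');
                }
            }
            return measurementValue * exponent;
        }

        public static void main(String[] args) {
            System.out.println(parse("-12.3")); // -12.3 (123 * -0.1, up to double rounding)
            System.out.println(parse("4.0")); // 4.0
        }
    }
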
-    static final class State {
+    private static final class State {
         ByteArrayKeyedMap aggregators = new ByteArrayKeyedMap();
         boolean parsingValue = false;
         byte[] stationBuffer = new byte[STATION_NAME_BUFFER_SIZE];
@@ -208,37 +164,51 @@ public class CalculateAverage_palmr {
         }

         public String toString() {
-            return round(min) + "/" + round(sum / count) + "/" + round(max);
+            return STR."\{round(min)}/\{round(sum / count)}/\{round(max)}";
         }

-        private double round(double value) {
+        private double round(final double value) {
             return Math.round(value * 10.0) / 10.0;
         }
+
+        private MeasurementAggregator merge(final MeasurementAggregator b) {
+            this.count += b.count;
+            this.min = Math.min(this.min, b.min);
+            this.max = Math.max(this.max, b.max);
+            this.sum += b.sum;
+            return this;
+        }
     }
+    /**
+     * Very basic hash table implementation, only implementing computeIfAbsent since that's all the code needs.
+     * It's sized to give minimal collisions with the example test set. This may not hold true if the stations list
+     * changes, but it should still perform fairly well.
+     * It uses Open Addressing, meaning it's just one array, rather than Separate Chaining, which is what the default Java HashMap uses.
+     * It also uses Linear Probing for collision resolution, which given the minimal collision count should hold up well.
+     */
     private static class ByteArrayKeyedMap {
         private final int BUCKET_COUNT = 0xFFF; // 413 unique stations in the data set, & 0xFFF ~= 399 (only 14 collisions (given our hashcode implementation))
         private final MeasurementAggregator[] buckets = new MeasurementAggregator[BUCKET_COUNT + 1];
         private final List<MeasurementAggregator> compactUnorderedBuckets = new ArrayList<>(413);

         public MeasurementAggregator computeIfAbsent(final byte[] key, final int keyLength, final int keyHashCode) {
-            int index = keyHashCode & BUCKET_COUNT;
+            var index = keyHashCode & BUCKET_COUNT;
             while (true) {
                 MeasurementAggregator maybe = buckets[index];
-                if (maybe == null) {
-                    final byte[] copiedKey = Arrays.copyOf(key, keyLength);
-                    MeasurementAggregator measurementAggregator = new MeasurementAggregator(copiedKey, keyHashCode);
-                    buckets[index] = measurementAggregator;
-                    compactUnorderedBuckets.add(measurementAggregator);
-                    return measurementAggregator;
-                }
-                else {
+                if (maybe != null) {
                     if (Arrays.equals(key, 0, keyLength, maybe.stationNameBytes, 0, maybe.stationNameBytes.length)) {
                         return maybe;
                     }
                     index++;
                     index &= BUCKET_COUNT;
+                } else {
+                    final var copiedKey = Arrays.copyOf(key, keyLength);
+                    MeasurementAggregator measurementAggregator = new MeasurementAggregator(copiedKey, keyHashCode);
+                    buckets[index] = measurementAggregator;
+                    compactUnorderedBuckets.add(measurementAggregator);
+                    return measurementAggregator;
                 }
             }
         }
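
For readers new to open addressing: on a collision the map steps to the next slot (wrapping via the bit mask) until it finds either the matching key or an empty bucket, which is why a table comfortably larger than the 413 stations matters. The probe loop reduced to its essentials (a sketch with String keys, not the byte-array version above):

    public class LinearProbingSketch {
        static final int MASK = 0xF; // 16 slots; must stay well above the live key count
        static final String[] slots = new String[MASK + 1];

        // Return the slot for key, inserting it if absent. Assumes the table never fills,
        // otherwise this loop would not terminate (mirrors ByteArrayKeyedMap's sizing bet).
        static int slotFor(String key) {
            var index = key.hashCode() & MASK;
            while (slots[index] != null && !slots[index].equals(key)) {
                index = (index + 1) & MASK; // linear probe, wrapping at the end of the array
            }
            slots[index] = key;
            return index;
        }

        public static void main(String[] args) {
            System.out.println(slotFor("Oslo"));
            System.out.println(slotFor("Oslo")); // same slot: key found on the first probe
        }
    }
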