Update seijikun implementation
* Use Integer calculation instead of double, add unit-test * Bring back StationIdent optimization Originally, StationIdent was using byte[] to store names, so the extra String allocation could be avoided. However, that produced incorrect sorting. Sorting is now moved to the result merging step. Here, names are converted to Strings. * Implement readStationName with SIMD 256bit * Rebase and cleanup test code, now that it's in the project * Fix seijikun formatting * Fix test failure in specific jobCnt edge-cases * Also switch to graalvm
This commit is contained in:
parent
e3f6c3aaf7
commit
36dac255cf
@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
JAVA_OPTS="-XX:+UseParallelGC --enable-preview --add-modules jdk.incubator.vector"
|
||||||
JAVA_OPTS="--enable-preview"
|
source "$HOME/.sdkman/bin/sdkman-init.sh"
|
||||||
|
sdk use java 21.0.1-graal 1>&2
|
||||||
time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_seijikun
|
time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_seijikun
|
||||||
|
@ -15,9 +15,18 @@
|
|||||||
*/
|
*/
|
||||||
package dev.morling.onebrc;
|
package dev.morling.onebrc;
|
||||||
|
|
||||||
import java.io.*;
|
import jdk.incubator.vector.ByteVector;
|
||||||
|
import jdk.incubator.vector.VectorOperators;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.PrintStream;
|
||||||
|
import java.io.RandomAccessFile;
|
||||||
|
import java.lang.foreign.MemorySegment;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
import java.nio.MappedByteBuffer;
|
import java.nio.MappedByteBuffer;
|
||||||
import java.nio.channels.FileChannel;
|
import java.nio.channels.FileChannel;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
@ -27,24 +36,36 @@ public class CalculateAverage_seijikun {
|
|||||||
private static final String FILE = "./measurements.txt";
|
private static final String FILE = "./measurements.txt";
|
||||||
|
|
||||||
private static class MeasurementAggregator {
|
private static class MeasurementAggregator {
|
||||||
private double min = Double.POSITIVE_INFINITY;
|
private int min = Integer.MAX_VALUE;
|
||||||
private double max = Double.NEGATIVE_INFINITY;
|
private int max = Integer.MIN_VALUE;
|
||||||
private double sum;
|
// final long startTs = System.currentTimeMillis();
|
||||||
private long count;
|
private long sum = 0;
|
||||||
|
private long count = 0;
|
||||||
|
|
||||||
|
private double mean = 0;
|
||||||
|
|
||||||
|
public void finish() {
|
||||||
|
double sum = this.sum / 10.0;
|
||||||
|
mean = sum / (double) count;
|
||||||
|
}
|
||||||
|
|
||||||
public void printInto(PrintStream out) {
|
public void printInto(PrintStream out) {
|
||||||
out.printf("%.1f/%.1f/%.1f", min, (sum / (double) count), max);
|
double min = (double) this.min / 10.0;
|
||||||
|
double max = (double) this.max / 10.0;
|
||||||
|
out.printf("%.1f/%.1f/%.1f", min, mean, max);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class StationIdent implements Comparable<StationIdent> {
|
public static class StationIdent {
|
||||||
private final int nameLength;
|
private final byte[] name;
|
||||||
private final String name;
|
|
||||||
private final int nameHash;
|
private final int nameHash;
|
||||||
|
|
||||||
public StationIdent(byte[] name, int nameHash) {
|
public StationIdent(byte[] name, int nameHash) {
|
||||||
this.nameLength = name.length;
|
this.name = name;
|
||||||
this.name = new String(name);
|
// TODO: DEBUG
|
||||||
|
// if(Arrays.asList(this.name).contains(';')) {
|
||||||
|
// throw new RuntimeException();
|
||||||
|
// }
|
||||||
this.nameHash = nameHash;
|
this.nameHash = nameHash;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,15 +77,10 @@ public class CalculateAverage_seijikun {
|
|||||||
@Override
|
@Override
|
||||||
public boolean equals(Object obj) {
|
public boolean equals(Object obj) {
|
||||||
var other = (StationIdent) obj;
|
var other = (StationIdent) obj;
|
||||||
if (other.nameLength != nameLength) {
|
if (other.name.length != name.length) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return name.equals(other.name);
|
return Arrays.equals(name, other.name);
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int compareTo(StationIdent o) {
|
|
||||||
return name.compareTo(o.name);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,9 +93,11 @@ public class CalculateAverage_seijikun {
|
|||||||
private final long endOffset;
|
private final long endOffset;
|
||||||
|
|
||||||
// state
|
// state
|
||||||
|
private int chunkSize = 0;
|
||||||
private MappedByteBuffer buffer = null;
|
private MappedByteBuffer buffer = null;
|
||||||
|
private MemorySegment memorySegment = null;
|
||||||
private int ptr = 0;
|
private int ptr = 0;
|
||||||
private TreeMap<StationIdent, MeasurementAggregator> workSet;
|
private HashMap<StationIdent, MeasurementAggregator> workSet;
|
||||||
|
|
||||||
public ChunkReader(RandomAccessFile file, long startOffset, long endOffset) {
|
public ChunkReader(RandomAccessFile file, long startOffset, long endOffset) {
|
||||||
this.file = file;
|
this.file = file;
|
||||||
@ -87,36 +105,67 @@ public class CalculateAverage_seijikun {
|
|||||||
this.endOffset = endOffset;
|
this.endOffset = endOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// private StationIdent readStationName() {
|
||||||
|
// int startPtr = ptr;
|
||||||
|
// int hashCode = 0;
|
||||||
|
// int hashBytePtr = 0;
|
||||||
|
// byte c;
|
||||||
|
// while ((c = buffer.get(ptr++)) != ';') {
|
||||||
|
// hashCode ^= ((int) c) << (hashBytePtr * 8);
|
||||||
|
// hashBytePtr = (hashBytePtr + 1) % 4;
|
||||||
|
// }
|
||||||
|
// byte[] stationNameBfr = new byte[ptr - startPtr - 1];
|
||||||
|
// buffer.get(startPtr, stationNameBfr);
|
||||||
|
// return new StationIdent(stationNameBfr, hashCode);
|
||||||
|
// }
|
||||||
|
|
||||||
private StationIdent readStationName() {
|
private StationIdent readStationName() {
|
||||||
int startPtr = ptr;
|
final var VECTOR_SPECIES = ByteVector.SPECIES_256;
|
||||||
int hashCode = 0;
|
|
||||||
int hashBytePtr = 0;
|
if (chunkSize - ptr < VECTOR_SPECIES.length()) { // fallback
|
||||||
byte c;
|
int startPtr = ptr;
|
||||||
while ((c = buffer.get(ptr++)) != ';') {
|
while (buffer.get(ptr++) != ';') {
|
||||||
hashCode ^= ((int) c) << (hashBytePtr * 8);
|
}
|
||||||
hashBytePtr = (hashBytePtr + 1) % 4;
|
byte[] stationNameBfr = new byte[ptr - startPtr - 1];
|
||||||
|
buffer.get(startPtr, stationNameBfr);
|
||||||
|
return new StationIdent(stationNameBfr, Arrays.hashCode(stationNameBfr) ^ stationNameBfr.length);
|
||||||
|
}
|
||||||
|
else { // SIMD
|
||||||
|
int sepIdx = 0;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
ByteVector tmp = ByteVector.fromMemorySegment(VECTOR_SPECIES, memorySegment, ptr + sepIdx, ByteOrder.LITTLE_ENDIAN);
|
||||||
|
final var cmpResult = tmp.compare(VectorOperators.EQ, ';');
|
||||||
|
if (cmpResult.anyTrue()) {
|
||||||
|
sepIdx += cmpResult.firstTrue();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
sepIdx += tmp.length();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int endPtr = ptr + sepIdx;
|
||||||
|
byte[] stationNameBfr = new byte[endPtr - ptr];
|
||||||
|
buffer.get(ptr, stationNameBfr);
|
||||||
|
ptr = endPtr + 1;
|
||||||
|
return new StationIdent(stationNameBfr, Arrays.hashCode(stationNameBfr) ^ stationNameBfr.length);
|
||||||
}
|
}
|
||||||
byte[] stationNameBfr = new byte[ptr - startPtr - 1];
|
|
||||||
buffer.get(startPtr, stationNameBfr);
|
|
||||||
return new StationIdent(stationNameBfr, hashCode);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private double readTemperature() {
|
private int readTemperature() {
|
||||||
double ret = 0, div = 1;
|
int ret = 0;
|
||||||
byte c = buffer.get(ptr++);
|
byte c = buffer.get(ptr++);
|
||||||
boolean neg = (c == '-');
|
final boolean neg = (c == '-');
|
||||||
if (neg)
|
if (neg) {
|
||||||
c = buffer.get(ptr++);
|
c = buffer.get(ptr++);
|
||||||
|
}
|
||||||
|
|
||||||
do {
|
do {
|
||||||
ret = ret * 10 + c - '0';
|
if (c != '.') {
|
||||||
} while ((c = buffer.get(ptr++)) >= '0' && c <= '9');
|
ret = ret * 10 + c - '0';
|
||||||
|
|
||||||
if (c == '.') {
|
|
||||||
while ((c = buffer.get(ptr++)) != '\n') {
|
|
||||||
ret += (c - '0') / (div *= 10);
|
|
||||||
}
|
}
|
||||||
}
|
} while ((c = buffer.get(ptr++)) != '\n');
|
||||||
|
|
||||||
if (neg)
|
if (neg)
|
||||||
return -ret;
|
return -ret;
|
||||||
@ -125,14 +174,18 @@ public class CalculateAverage_seijikun {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
workSet = new TreeMap<>();
|
workSet = new HashMap<>();
|
||||||
int chunkSize = (int) (endOffset - startOffset);
|
if (endOffset - startOffset > Integer.MAX_VALUE) {
|
||||||
|
throw new RuntimeException("Mapping a block larger than 2GB is not possible with Java! Welcome to 2024 :)");
|
||||||
|
}
|
||||||
|
chunkSize = (int) (endOffset - startOffset);
|
||||||
try {
|
try {
|
||||||
buffer = file.getChannel().map(FileChannel.MapMode.READ_ONLY, startOffset, chunkSize);
|
buffer = file.getChannel().map(FileChannel.MapMode.READ_ONLY, startOffset, chunkSize);
|
||||||
|
memorySegment = MemorySegment.ofBuffer(buffer);
|
||||||
|
|
||||||
while (ptr < chunkSize) {
|
while (ptr < chunkSize) {
|
||||||
var station = readStationName();
|
var station = readStationName();
|
||||||
var temp = readTemperature();
|
int temp = readTemperature();
|
||||||
var stationWorkSet = workSet.get(station);
|
var stationWorkSet = workSet.get(station);
|
||||||
if (stationWorkSet == null) {
|
if (stationWorkSet == null) {
|
||||||
stationWorkSet = new MeasurementAggregator();
|
stationWorkSet = new MeasurementAggregator();
|
||||||
@ -144,26 +197,42 @@ public class CalculateAverage_seijikun {
|
|||||||
stationWorkSet.count += 1;
|
stationWorkSet.count += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch (IOException e) {
|
catch (Throwable e) {
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException, InterruptedException {
|
private static void printWorkSet(TreeMap<String, MeasurementAggregator> result, PrintStream out) {
|
||||||
RandomAccessFile file = new RandomAccessFile(FILE, "r");
|
out.write('{');
|
||||||
|
final var iterator = result.entrySet().iterator();
|
||||||
|
while (iterator.hasNext()) {
|
||||||
|
var entry = iterator.next();
|
||||||
|
out.print(entry.getKey());
|
||||||
|
out.write('=');
|
||||||
|
entry.getValue().printInto(out);
|
||||||
|
if (iterator.hasNext()) {
|
||||||
|
out.print(", ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out.println('}');
|
||||||
|
}
|
||||||
|
|
||||||
int jobCnt = Runtime.getRuntime().availableProcessors();
|
private static int createChunks(final RandomAccessFile file, final ChunkReader[] chunks) throws IOException {
|
||||||
|
final long fileEndPtr = file.length();
|
||||||
|
final long chunkSize = Math.max(1, fileEndPtr / chunks.length);
|
||||||
|
|
||||||
var chunks = new ChunkReader[jobCnt];
|
int jobCnt = 0;
|
||||||
long chunkSize = file.length() / jobCnt;
|
|
||||||
long chunkStartPtr = 0;
|
long chunkStartPtr = 0;
|
||||||
byte[] tmpBuffer = new byte[128];
|
final byte[] tmpBuffer = new byte[128];
|
||||||
for (int i = 0; i < jobCnt; ++i) {
|
while (chunkStartPtr < fileEndPtr) {
|
||||||
long chunkEndPtr = chunkStartPtr + chunkSize;
|
long chunkEndPtr = Math.min(chunkStartPtr + chunkSize, fileEndPtr);
|
||||||
if (i != (jobCnt - 1)) { // align chunks to newlines
|
|
||||||
file.seek(chunkEndPtr - 1);
|
// Seek into file at the calculated chunk end ptr, then extend it until the next
|
||||||
|
// new-line or EOF
|
||||||
|
if (chunkEndPtr < fileEndPtr) {
|
||||||
|
file.seek(Math.max(0, chunkEndPtr - 1));
|
||||||
file.read(tmpBuffer);
|
file.read(tmpBuffer);
|
||||||
int offset = 0;
|
int offset = 0;
|
||||||
while (tmpBuffer[offset] != '\n') {
|
while (tmpBuffer[offset] != '\n') {
|
||||||
@ -171,28 +240,38 @@ public class CalculateAverage_seijikun {
|
|||||||
}
|
}
|
||||||
chunkEndPtr += offset;
|
chunkEndPtr += offset;
|
||||||
}
|
}
|
||||||
else { // last chunk ends at file end
|
|
||||||
chunkEndPtr = file.length();
|
chunks[jobCnt] = new ChunkReader(file, chunkStartPtr, chunkEndPtr);
|
||||||
}
|
jobCnt += 1;
|
||||||
chunks[i] = new ChunkReader(file, chunkStartPtr, chunkEndPtr);
|
|
||||||
chunkStartPtr = chunkEndPtr;
|
chunkStartPtr = chunkEndPtr;
|
||||||
}
|
}
|
||||||
|
return jobCnt;
|
||||||
|
}
|
||||||
|
|
||||||
try (var executor = Executors.newFixedThreadPool(jobCnt)) {
|
public static void main(String[] args) throws IOException, InterruptedException {
|
||||||
|
final RandomAccessFile file = new RandomAccessFile(FILE, "r");
|
||||||
|
|
||||||
|
int jobCnt = Runtime.getRuntime().availableProcessors();
|
||||||
|
|
||||||
|
final var chunks = new ChunkReader[jobCnt];
|
||||||
|
jobCnt = createChunks(file, chunks);
|
||||||
|
|
||||||
|
try (final var executor = Executors.newFixedThreadPool(jobCnt)) {
|
||||||
for (int i = 0; i < jobCnt; ++i) {
|
for (int i = 0; i < jobCnt; ++i) {
|
||||||
executor.submit(chunks[i]);
|
executor.submit(chunks[i]);
|
||||||
}
|
}
|
||||||
executor.shutdown();
|
executor.shutdown();
|
||||||
var ignored = executor.awaitTermination(1, TimeUnit.DAYS);
|
final var ignored = executor.awaitTermination(1, TimeUnit.DAYS);
|
||||||
}
|
}
|
||||||
|
|
||||||
// merge chunks
|
// merge chunks
|
||||||
var result = chunks[0].workSet;
|
final var result = new TreeMap<String, MeasurementAggregator>();
|
||||||
for (int i = 1; i < jobCnt; ++i) {
|
for (int i = 0; i < jobCnt; ++i) {
|
||||||
chunks[i].workSet.forEach((ident, otherStationWorkSet) -> {
|
chunks[i].workSet.forEach((ident, otherStationWorkSet) -> {
|
||||||
var stationWorkSet = result.get(ident);
|
final var identStr = new String(ident.name);
|
||||||
|
final var stationWorkSet = result.get(identStr);
|
||||||
if (stationWorkSet == null) {
|
if (stationWorkSet == null) {
|
||||||
result.put(ident, otherStationWorkSet);
|
result.put(identStr, otherStationWorkSet);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
stationWorkSet.min = Math.min(stationWorkSet.min, otherStationWorkSet.min);
|
stationWorkSet.min = Math.min(stationWorkSet.min, otherStationWorkSet.min);
|
||||||
@ -202,19 +281,9 @@ public class CalculateAverage_seijikun {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
result.forEach((ignored, meas) -> meas.finish());
|
||||||
|
|
||||||
// print in required format
|
// print in required format
|
||||||
System.out.write('{');
|
printWorkSet(result, System.out);
|
||||||
var iterator = result.entrySet().iterator();
|
|
||||||
while (iterator.hasNext()) {
|
|
||||||
var entry = iterator.next();
|
|
||||||
System.out.print(entry.getKey().name);
|
|
||||||
System.out.write('=');
|
|
||||||
entry.getValue().printInto(System.out);
|
|
||||||
if (iterator.hasNext()) {
|
|
||||||
System.out.print(", ");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
System.out.println('}');
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user