From 36dac255cf2b36811c5fc9b6fc9ca37e17bc34b6 Mon Sep 17 00:00:00 2001 From: Markus Ebner Date: Fri, 5 Jan 2024 19:35:15 +0100 Subject: [PATCH] Update seijikun implementation * Use Integer calculation instead of double, add unit-test * Bring back StationIdent optimization Originally, StationIdent was using byte[] to store names, so the extra String allocation could be avoided. However, that produced incorrect sorting. Sorting is now moved to the result merging step. Here, names are converted to Strings. * Implement readStationName with SIMD 256bit * Rebase and cleanup test code, now that it's in the project * Fix seijikun formatting * Fix test failure in specific jobCnt edge-cases * Also switch to graalvm --- calculate_average_seijikun.sh | 5 +- .../onebrc/CalculateAverage_seijikun.java | 221 ++++++++++++------ 2 files changed, 148 insertions(+), 78 deletions(-) diff --git a/calculate_average_seijikun.sh b/calculate_average_seijikun.sh index fbe68ad..5469b61 100755 --- a/calculate_average_seijikun.sh +++ b/calculate_average_seijikun.sh @@ -15,6 +15,7 @@ # limitations under the License. # - -JAVA_OPTS="--enable-preview" +JAVA_OPTS="-XX:+UseParallelGC --enable-preview --add-modules jdk.incubator.vector" +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.1-graal 1>&2 time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_seijikun diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_seijikun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_seijikun.java index bdea518..c5678b1 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_seijikun.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_seijikun.java @@ -15,9 +15,18 @@ */ package dev.morling.onebrc; -import java.io.*; +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.VectorOperators; + +import java.io.IOException; +import java.io.PrintStream; +import java.io.RandomAccessFile; +import java.lang.foreign.MemorySegment; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; +import java.util.Arrays; +import java.util.HashMap; import java.util.TreeMap; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -27,24 +36,36 @@ public class CalculateAverage_seijikun { private static final String FILE = "./measurements.txt"; private static class MeasurementAggregator { - private double min = Double.POSITIVE_INFINITY; - private double max = Double.NEGATIVE_INFINITY; - private double sum; - private long count; + private int min = Integer.MAX_VALUE; + private int max = Integer.MIN_VALUE; + // final long startTs = System.currentTimeMillis(); + private long sum = 0; + private long count = 0; + + private double mean = 0; + + public void finish() { + double sum = this.sum / 10.0; + mean = sum / (double) count; + } public void printInto(PrintStream out) { - out.printf("%.1f/%.1f/%.1f", min, (sum / (double) count), max); + double min = (double) this.min / 10.0; + double max = (double) this.max / 10.0; + out.printf("%.1f/%.1f/%.1f", min, mean, max); } } - public static class StationIdent implements Comparable { - private final int nameLength; - private final String name; + public static class StationIdent { + private final byte[] name; private final int nameHash; public StationIdent(byte[] name, int nameHash) { - this.nameLength = name.length; - this.name = new String(name); + this.name = name; + // TODO: DEBUG + // if(Arrays.asList(this.name).contains(';')) { + // throw new RuntimeException(); + // } this.nameHash = nameHash; } @@ -56,15 +77,10 @@ public class CalculateAverage_seijikun { @Override public boolean equals(Object obj) { var other = (StationIdent) obj; - if (other.nameLength != nameLength) { + if (other.name.length != name.length) { return false; } - return name.equals(other.name); - } - - @Override - public int compareTo(StationIdent o) { - return name.compareTo(o.name); + return Arrays.equals(name, other.name); } } @@ -77,9 +93,11 @@ public class CalculateAverage_seijikun { private final long endOffset; // state + private int chunkSize = 0; private MappedByteBuffer buffer = null; + private MemorySegment memorySegment = null; private int ptr = 0; - private TreeMap workSet; + private HashMap workSet; public ChunkReader(RandomAccessFile file, long startOffset, long endOffset) { this.file = file; @@ -87,36 +105,67 @@ public class CalculateAverage_seijikun { this.endOffset = endOffset; } + // private StationIdent readStationName() { + // int startPtr = ptr; + // int hashCode = 0; + // int hashBytePtr = 0; + // byte c; + // while ((c = buffer.get(ptr++)) != ';') { + // hashCode ^= ((int) c) << (hashBytePtr * 8); + // hashBytePtr = (hashBytePtr + 1) % 4; + // } + // byte[] stationNameBfr = new byte[ptr - startPtr - 1]; + // buffer.get(startPtr, stationNameBfr); + // return new StationIdent(stationNameBfr, hashCode); + // } + private StationIdent readStationName() { - int startPtr = ptr; - int hashCode = 0; - int hashBytePtr = 0; - byte c; - while ((c = buffer.get(ptr++)) != ';') { - hashCode ^= ((int) c) << (hashBytePtr * 8); - hashBytePtr = (hashBytePtr + 1) % 4; + final var VECTOR_SPECIES = ByteVector.SPECIES_256; + + if (chunkSize - ptr < VECTOR_SPECIES.length()) { // fallback + int startPtr = ptr; + while (buffer.get(ptr++) != ';') { + } + byte[] stationNameBfr = new byte[ptr - startPtr - 1]; + buffer.get(startPtr, stationNameBfr); + return new StationIdent(stationNameBfr, Arrays.hashCode(stationNameBfr) ^ stationNameBfr.length); + } + else { // SIMD + int sepIdx = 0; + + while (true) { + ByteVector tmp = ByteVector.fromMemorySegment(VECTOR_SPECIES, memorySegment, ptr + sepIdx, ByteOrder.LITTLE_ENDIAN); + final var cmpResult = tmp.compare(VectorOperators.EQ, ';'); + if (cmpResult.anyTrue()) { + sepIdx += cmpResult.firstTrue(); + break; + } + else { + sepIdx += tmp.length(); + } + } + + int endPtr = ptr + sepIdx; + byte[] stationNameBfr = new byte[endPtr - ptr]; + buffer.get(ptr, stationNameBfr); + ptr = endPtr + 1; + return new StationIdent(stationNameBfr, Arrays.hashCode(stationNameBfr) ^ stationNameBfr.length); } - byte[] stationNameBfr = new byte[ptr - startPtr - 1]; - buffer.get(startPtr, stationNameBfr); - return new StationIdent(stationNameBfr, hashCode); } - private double readTemperature() { - double ret = 0, div = 1; + private int readTemperature() { + int ret = 0; byte c = buffer.get(ptr++); - boolean neg = (c == '-'); - if (neg) + final boolean neg = (c == '-'); + if (neg) { c = buffer.get(ptr++); + } do { - ret = ret * 10 + c - '0'; - } while ((c = buffer.get(ptr++)) >= '0' && c <= '9'); - - if (c == '.') { - while ((c = buffer.get(ptr++)) != '\n') { - ret += (c - '0') / (div *= 10); + if (c != '.') { + ret = ret * 10 + c - '0'; } - } + } while ((c = buffer.get(ptr++)) != '\n'); if (neg) return -ret; @@ -125,14 +174,18 @@ public class CalculateAverage_seijikun { @Override public void run() { - workSet = new TreeMap<>(); - int chunkSize = (int) (endOffset - startOffset); + workSet = new HashMap<>(); + if (endOffset - startOffset > Integer.MAX_VALUE) { + throw new RuntimeException("Mapping a block larger than 2GB is not possible with Java! Welcome to 2024 :)"); + } + chunkSize = (int) (endOffset - startOffset); try { buffer = file.getChannel().map(FileChannel.MapMode.READ_ONLY, startOffset, chunkSize); + memorySegment = MemorySegment.ofBuffer(buffer); while (ptr < chunkSize) { var station = readStationName(); - var temp = readTemperature(); + int temp = readTemperature(); var stationWorkSet = workSet.get(station); if (stationWorkSet == null) { stationWorkSet = new MeasurementAggregator(); @@ -144,26 +197,42 @@ public class CalculateAverage_seijikun { stationWorkSet.count += 1; } } - catch (IOException e) { + catch (Throwable e) { e.printStackTrace(); throw new RuntimeException(e); } } } - public static void main(String[] args) throws IOException, InterruptedException { - RandomAccessFile file = new RandomAccessFile(FILE, "r"); + private static void printWorkSet(TreeMap result, PrintStream out) { + out.write('{'); + final var iterator = result.entrySet().iterator(); + while (iterator.hasNext()) { + var entry = iterator.next(); + out.print(entry.getKey()); + out.write('='); + entry.getValue().printInto(out); + if (iterator.hasNext()) { + out.print(", "); + } + } + out.println('}'); + } - int jobCnt = Runtime.getRuntime().availableProcessors(); + private static int createChunks(final RandomAccessFile file, final ChunkReader[] chunks) throws IOException { + final long fileEndPtr = file.length(); + final long chunkSize = Math.max(1, fileEndPtr / chunks.length); - var chunks = new ChunkReader[jobCnt]; - long chunkSize = file.length() / jobCnt; + int jobCnt = 0; long chunkStartPtr = 0; - byte[] tmpBuffer = new byte[128]; - for (int i = 0; i < jobCnt; ++i) { - long chunkEndPtr = chunkStartPtr + chunkSize; - if (i != (jobCnt - 1)) { // align chunks to newlines - file.seek(chunkEndPtr - 1); + final byte[] tmpBuffer = new byte[128]; + while (chunkStartPtr < fileEndPtr) { + long chunkEndPtr = Math.min(chunkStartPtr + chunkSize, fileEndPtr); + + // Seek into file at the calculated chunk end ptr, then extend it until the next + // new-line or EOF + if (chunkEndPtr < fileEndPtr) { + file.seek(Math.max(0, chunkEndPtr - 1)); file.read(tmpBuffer); int offset = 0; while (tmpBuffer[offset] != '\n') { @@ -171,28 +240,38 @@ public class CalculateAverage_seijikun { } chunkEndPtr += offset; } - else { // last chunk ends at file end - chunkEndPtr = file.length(); - } - chunks[i] = new ChunkReader(file, chunkStartPtr, chunkEndPtr); + + chunks[jobCnt] = new ChunkReader(file, chunkStartPtr, chunkEndPtr); + jobCnt += 1; chunkStartPtr = chunkEndPtr; } + return jobCnt; + } - try (var executor = Executors.newFixedThreadPool(jobCnt)) { + public static void main(String[] args) throws IOException, InterruptedException { + final RandomAccessFile file = new RandomAccessFile(FILE, "r"); + + int jobCnt = Runtime.getRuntime().availableProcessors(); + + final var chunks = new ChunkReader[jobCnt]; + jobCnt = createChunks(file, chunks); + + try (final var executor = Executors.newFixedThreadPool(jobCnt)) { for (int i = 0; i < jobCnt; ++i) { executor.submit(chunks[i]); } executor.shutdown(); - var ignored = executor.awaitTermination(1, TimeUnit.DAYS); + final var ignored = executor.awaitTermination(1, TimeUnit.DAYS); } // merge chunks - var result = chunks[0].workSet; - for (int i = 1; i < jobCnt; ++i) { + final var result = new TreeMap(); + for (int i = 0; i < jobCnt; ++i) { chunks[i].workSet.forEach((ident, otherStationWorkSet) -> { - var stationWorkSet = result.get(ident); + final var identStr = new String(ident.name); + final var stationWorkSet = result.get(identStr); if (stationWorkSet == null) { - result.put(ident, otherStationWorkSet); + result.put(identStr, otherStationWorkSet); } else { stationWorkSet.min = Math.min(stationWorkSet.min, otherStationWorkSet.min); @@ -202,19 +281,9 @@ public class CalculateAverage_seijikun { } }); } + result.forEach((ignored, meas) -> meas.finish()); // print in required format - System.out.write('{'); - var iterator = result.entrySet().iterator(); - while (iterator.hasNext()) { - var entry = iterator.next(); - System.out.print(entry.getKey().name); - System.out.write('='); - entry.getValue().printInto(System.out); - if (iterator.hasNext()) { - System.out.print(", "); - } - } - System.out.println('}'); + printWorkSet(result, System.out); } }