Update seijikun implementation
* Use Integer calculation instead of double, add unit-test * Bring back StationIdent optimization Originally, StationIdent was using byte[] to store names, so the extra String allocation could be avoided. However, that produced incorrect sorting. Sorting is now moved to the result merging step. Here, names are converted to Strings. * Implement readStationName with SIMD 256bit * Rebase and cleanup test code, now that it's in the project * Fix seijikun formatting * Fix test failure in specific jobCnt edge-cases * Also switch to graalvm
This commit is contained in:
		@@ -15,9 +15,18 @@
 | 
			
		||||
 */
 | 
			
		||||
package dev.morling.onebrc;
 | 
			
		||||
 | 
			
		||||
import java.io.*;
 | 
			
		||||
import jdk.incubator.vector.ByteVector;
 | 
			
		||||
import jdk.incubator.vector.VectorOperators;
 | 
			
		||||
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import java.io.PrintStream;
 | 
			
		||||
import java.io.RandomAccessFile;
 | 
			
		||||
import java.lang.foreign.MemorySegment;
 | 
			
		||||
import java.nio.ByteOrder;
 | 
			
		||||
import java.nio.MappedByteBuffer;
 | 
			
		||||
import java.nio.channels.FileChannel;
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.TreeMap;
 | 
			
		||||
import java.util.concurrent.Executors;
 | 
			
		||||
import java.util.concurrent.TimeUnit;
 | 
			
		||||
@@ -27,24 +36,36 @@ public class CalculateAverage_seijikun {
 | 
			
		||||
    private static final String FILE = "./measurements.txt";
 | 
			
		||||
 | 
			
		||||
    private static class MeasurementAggregator {
 | 
			
		||||
        private double min = Double.POSITIVE_INFINITY;
 | 
			
		||||
        private double max = Double.NEGATIVE_INFINITY;
 | 
			
		||||
        private double sum;
 | 
			
		||||
        private long count;
 | 
			
		||||
        private int min = Integer.MAX_VALUE;
 | 
			
		||||
        private int max = Integer.MIN_VALUE;
 | 
			
		||||
        // final long startTs = System.currentTimeMillis();
 | 
			
		||||
        private long sum = 0;
 | 
			
		||||
        private long count = 0;
 | 
			
		||||
 | 
			
		||||
        private double mean = 0;
 | 
			
		||||
 | 
			
		||||
        public void finish() {
 | 
			
		||||
            double sum = this.sum / 10.0;
 | 
			
		||||
            mean = sum / (double) count;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        public void printInto(PrintStream out) {
 | 
			
		||||
            out.printf("%.1f/%.1f/%.1f", min, (sum / (double) count), max);
 | 
			
		||||
            double min = (double) this.min / 10.0;
 | 
			
		||||
            double max = (double) this.max / 10.0;
 | 
			
		||||
            out.printf("%.1f/%.1f/%.1f", min, mean, max);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static class StationIdent implements Comparable<StationIdent> {
 | 
			
		||||
        private final int nameLength;
 | 
			
		||||
        private final String name;
 | 
			
		||||
    public static class StationIdent {
 | 
			
		||||
        private final byte[] name;
 | 
			
		||||
        private final int nameHash;
 | 
			
		||||
 | 
			
		||||
        public StationIdent(byte[] name, int nameHash) {
 | 
			
		||||
            this.nameLength = name.length;
 | 
			
		||||
            this.name = new String(name);
 | 
			
		||||
            this.name = name;
 | 
			
		||||
            // TODO: DEBUG
 | 
			
		||||
            // if(Arrays.asList(this.name).contains(';')) {
 | 
			
		||||
            // throw new RuntimeException();
 | 
			
		||||
            // }
 | 
			
		||||
            this.nameHash = nameHash;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -56,15 +77,10 @@ public class CalculateAverage_seijikun {
 | 
			
		||||
        @Override
 | 
			
		||||
        public boolean equals(Object obj) {
 | 
			
		||||
            var other = (StationIdent) obj;
 | 
			
		||||
            if (other.nameLength != nameLength) {
 | 
			
		||||
            if (other.name.length != name.length) {
 | 
			
		||||
                return false;
 | 
			
		||||
            }
 | 
			
		||||
            return name.equals(other.name);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public int compareTo(StationIdent o) {
 | 
			
		||||
            return name.compareTo(o.name);
 | 
			
		||||
            return Arrays.equals(name, other.name);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -77,9 +93,11 @@ public class CalculateAverage_seijikun {
 | 
			
		||||
        private final long endOffset;
 | 
			
		||||
 | 
			
		||||
        // state
 | 
			
		||||
        private int chunkSize = 0;
 | 
			
		||||
        private MappedByteBuffer buffer = null;
 | 
			
		||||
        private MemorySegment memorySegment = null;
 | 
			
		||||
        private int ptr = 0;
 | 
			
		||||
        private TreeMap<StationIdent, MeasurementAggregator> workSet;
 | 
			
		||||
        private HashMap<StationIdent, MeasurementAggregator> workSet;
 | 
			
		||||
 | 
			
		||||
        public ChunkReader(RandomAccessFile file, long startOffset, long endOffset) {
 | 
			
		||||
            this.file = file;
 | 
			
		||||
@@ -87,36 +105,67 @@ public class CalculateAverage_seijikun {
 | 
			
		||||
            this.endOffset = endOffset;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // private StationIdent readStationName() {
 | 
			
		||||
        // int startPtr = ptr;
 | 
			
		||||
        // int hashCode = 0;
 | 
			
		||||
        // int hashBytePtr = 0;
 | 
			
		||||
        // byte c;
 | 
			
		||||
        // while ((c = buffer.get(ptr++)) != ';') {
 | 
			
		||||
        // hashCode ^= ((int) c) << (hashBytePtr * 8);
 | 
			
		||||
        // hashBytePtr = (hashBytePtr + 1) % 4;
 | 
			
		||||
        // }
 | 
			
		||||
        // byte[] stationNameBfr = new byte[ptr - startPtr - 1];
 | 
			
		||||
        // buffer.get(startPtr, stationNameBfr);
 | 
			
		||||
        // return new StationIdent(stationNameBfr, hashCode);
 | 
			
		||||
        // }
 | 
			
		||||
 | 
			
		||||
        private StationIdent readStationName() {
 | 
			
		||||
            int startPtr = ptr;
 | 
			
		||||
            int hashCode = 0;
 | 
			
		||||
            int hashBytePtr = 0;
 | 
			
		||||
            byte c;
 | 
			
		||||
            while ((c = buffer.get(ptr++)) != ';') {
 | 
			
		||||
                hashCode ^= ((int) c) << (hashBytePtr * 8);
 | 
			
		||||
                hashBytePtr = (hashBytePtr + 1) % 4;
 | 
			
		||||
            final var VECTOR_SPECIES = ByteVector.SPECIES_256;
 | 
			
		||||
 | 
			
		||||
            if (chunkSize - ptr < VECTOR_SPECIES.length()) { // fallback
 | 
			
		||||
                int startPtr = ptr;
 | 
			
		||||
                while (buffer.get(ptr++) != ';') {
 | 
			
		||||
                }
 | 
			
		||||
                byte[] stationNameBfr = new byte[ptr - startPtr - 1];
 | 
			
		||||
                buffer.get(startPtr, stationNameBfr);
 | 
			
		||||
                return new StationIdent(stationNameBfr, Arrays.hashCode(stationNameBfr) ^ stationNameBfr.length);
 | 
			
		||||
            }
 | 
			
		||||
            else { // SIMD
 | 
			
		||||
                int sepIdx = 0;
 | 
			
		||||
 | 
			
		||||
                while (true) {
 | 
			
		||||
                    ByteVector tmp = ByteVector.fromMemorySegment(VECTOR_SPECIES, memorySegment, ptr + sepIdx, ByteOrder.LITTLE_ENDIAN);
 | 
			
		||||
                    final var cmpResult = tmp.compare(VectorOperators.EQ, ';');
 | 
			
		||||
                    if (cmpResult.anyTrue()) {
 | 
			
		||||
                        sepIdx += cmpResult.firstTrue();
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                    else {
 | 
			
		||||
                        sepIdx += tmp.length();
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                int endPtr = ptr + sepIdx;
 | 
			
		||||
                byte[] stationNameBfr = new byte[endPtr - ptr];
 | 
			
		||||
                buffer.get(ptr, stationNameBfr);
 | 
			
		||||
                ptr = endPtr + 1;
 | 
			
		||||
                return new StationIdent(stationNameBfr, Arrays.hashCode(stationNameBfr) ^ stationNameBfr.length);
 | 
			
		||||
            }
 | 
			
		||||
            byte[] stationNameBfr = new byte[ptr - startPtr - 1];
 | 
			
		||||
            buffer.get(startPtr, stationNameBfr);
 | 
			
		||||
            return new StationIdent(stationNameBfr, hashCode);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        private double readTemperature() {
 | 
			
		||||
            double ret = 0, div = 1;
 | 
			
		||||
        private int readTemperature() {
 | 
			
		||||
            int ret = 0;
 | 
			
		||||
            byte c = buffer.get(ptr++);
 | 
			
		||||
            boolean neg = (c == '-');
 | 
			
		||||
            if (neg)
 | 
			
		||||
            final boolean neg = (c == '-');
 | 
			
		||||
            if (neg) {
 | 
			
		||||
                c = buffer.get(ptr++);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            do {
 | 
			
		||||
                ret = ret * 10 + c - '0';
 | 
			
		||||
            } while ((c = buffer.get(ptr++)) >= '0' && c <= '9');
 | 
			
		||||
 | 
			
		||||
            if (c == '.') {
 | 
			
		||||
                while ((c = buffer.get(ptr++)) != '\n') {
 | 
			
		||||
                    ret += (c - '0') / (div *= 10);
 | 
			
		||||
                if (c != '.') {
 | 
			
		||||
                    ret = ret * 10 + c - '0';
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            } while ((c = buffer.get(ptr++)) != '\n');
 | 
			
		||||
 | 
			
		||||
            if (neg)
 | 
			
		||||
                return -ret;
 | 
			
		||||
@@ -125,14 +174,18 @@ public class CalculateAverage_seijikun {
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public void run() {
 | 
			
		||||
            workSet = new TreeMap<>();
 | 
			
		||||
            int chunkSize = (int) (endOffset - startOffset);
 | 
			
		||||
            workSet = new HashMap<>();
 | 
			
		||||
            if (endOffset - startOffset > Integer.MAX_VALUE) {
 | 
			
		||||
                throw new RuntimeException("Mapping a block larger than 2GB is not possible with Java! Welcome to 2024 :)");
 | 
			
		||||
            }
 | 
			
		||||
            chunkSize = (int) (endOffset - startOffset);
 | 
			
		||||
            try {
 | 
			
		||||
                buffer = file.getChannel().map(FileChannel.MapMode.READ_ONLY, startOffset, chunkSize);
 | 
			
		||||
                memorySegment = MemorySegment.ofBuffer(buffer);
 | 
			
		||||
 | 
			
		||||
                while (ptr < chunkSize) {
 | 
			
		||||
                    var station = readStationName();
 | 
			
		||||
                    var temp = readTemperature();
 | 
			
		||||
                    int temp = readTemperature();
 | 
			
		||||
                    var stationWorkSet = workSet.get(station);
 | 
			
		||||
                    if (stationWorkSet == null) {
 | 
			
		||||
                        stationWorkSet = new MeasurementAggregator();
 | 
			
		||||
@@ -144,26 +197,42 @@ public class CalculateAverage_seijikun {
 | 
			
		||||
                    stationWorkSet.count += 1;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            catch (IOException e) {
 | 
			
		||||
            catch (Throwable e) {
 | 
			
		||||
                e.printStackTrace();
 | 
			
		||||
                throw new RuntimeException(e);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static void main(String[] args) throws IOException, InterruptedException {
 | 
			
		||||
        RandomAccessFile file = new RandomAccessFile(FILE, "r");
 | 
			
		||||
    private static void printWorkSet(TreeMap<String, MeasurementAggregator> result, PrintStream out) {
 | 
			
		||||
        out.write('{');
 | 
			
		||||
        final var iterator = result.entrySet().iterator();
 | 
			
		||||
        while (iterator.hasNext()) {
 | 
			
		||||
            var entry = iterator.next();
 | 
			
		||||
            out.print(entry.getKey());
 | 
			
		||||
            out.write('=');
 | 
			
		||||
            entry.getValue().printInto(out);
 | 
			
		||||
            if (iterator.hasNext()) {
 | 
			
		||||
                out.print(", ");
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        out.println('}');
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
        int jobCnt = Runtime.getRuntime().availableProcessors();
 | 
			
		||||
    private static int createChunks(final RandomAccessFile file, final ChunkReader[] chunks) throws IOException {
 | 
			
		||||
        final long fileEndPtr = file.length();
 | 
			
		||||
        final long chunkSize = Math.max(1, fileEndPtr / chunks.length);
 | 
			
		||||
 | 
			
		||||
        var chunks = new ChunkReader[jobCnt];
 | 
			
		||||
        long chunkSize = file.length() / jobCnt;
 | 
			
		||||
        int jobCnt = 0;
 | 
			
		||||
        long chunkStartPtr = 0;
 | 
			
		||||
        byte[] tmpBuffer = new byte[128];
 | 
			
		||||
        for (int i = 0; i < jobCnt; ++i) {
 | 
			
		||||
            long chunkEndPtr = chunkStartPtr + chunkSize;
 | 
			
		||||
            if (i != (jobCnt - 1)) { // align chunks to newlines
 | 
			
		||||
                file.seek(chunkEndPtr - 1);
 | 
			
		||||
        final byte[] tmpBuffer = new byte[128];
 | 
			
		||||
        while (chunkStartPtr < fileEndPtr) {
 | 
			
		||||
            long chunkEndPtr = Math.min(chunkStartPtr + chunkSize, fileEndPtr);
 | 
			
		||||
 | 
			
		||||
            // Seek into file at the calculated chunk end ptr, then extend it until the next
 | 
			
		||||
            // new-line or EOF
 | 
			
		||||
            if (chunkEndPtr < fileEndPtr) {
 | 
			
		||||
                file.seek(Math.max(0, chunkEndPtr - 1));
 | 
			
		||||
                file.read(tmpBuffer);
 | 
			
		||||
                int offset = 0;
 | 
			
		||||
                while (tmpBuffer[offset] != '\n') {
 | 
			
		||||
@@ -171,28 +240,38 @@ public class CalculateAverage_seijikun {
 | 
			
		||||
                }
 | 
			
		||||
                chunkEndPtr += offset;
 | 
			
		||||
            }
 | 
			
		||||
            else { // last chunk ends at file end
 | 
			
		||||
                chunkEndPtr = file.length();
 | 
			
		||||
            }
 | 
			
		||||
            chunks[i] = new ChunkReader(file, chunkStartPtr, chunkEndPtr);
 | 
			
		||||
 | 
			
		||||
            chunks[jobCnt] = new ChunkReader(file, chunkStartPtr, chunkEndPtr);
 | 
			
		||||
            jobCnt += 1;
 | 
			
		||||
            chunkStartPtr = chunkEndPtr;
 | 
			
		||||
        }
 | 
			
		||||
        return jobCnt;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
        try (var executor = Executors.newFixedThreadPool(jobCnt)) {
 | 
			
		||||
    public static void main(String[] args) throws IOException, InterruptedException {
 | 
			
		||||
        final RandomAccessFile file = new RandomAccessFile(FILE, "r");
 | 
			
		||||
 | 
			
		||||
        int jobCnt = Runtime.getRuntime().availableProcessors();
 | 
			
		||||
 | 
			
		||||
        final var chunks = new ChunkReader[jobCnt];
 | 
			
		||||
        jobCnt = createChunks(file, chunks);
 | 
			
		||||
 | 
			
		||||
        try (final var executor = Executors.newFixedThreadPool(jobCnt)) {
 | 
			
		||||
            for (int i = 0; i < jobCnt; ++i) {
 | 
			
		||||
                executor.submit(chunks[i]);
 | 
			
		||||
            }
 | 
			
		||||
            executor.shutdown();
 | 
			
		||||
            var ignored = executor.awaitTermination(1, TimeUnit.DAYS);
 | 
			
		||||
            final var ignored = executor.awaitTermination(1, TimeUnit.DAYS);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // merge chunks
 | 
			
		||||
        var result = chunks[0].workSet;
 | 
			
		||||
        for (int i = 1; i < jobCnt; ++i) {
 | 
			
		||||
        final var result = new TreeMap<String, MeasurementAggregator>();
 | 
			
		||||
        for (int i = 0; i < jobCnt; ++i) {
 | 
			
		||||
            chunks[i].workSet.forEach((ident, otherStationWorkSet) -> {
 | 
			
		||||
                var stationWorkSet = result.get(ident);
 | 
			
		||||
                final var identStr = new String(ident.name);
 | 
			
		||||
                final var stationWorkSet = result.get(identStr);
 | 
			
		||||
                if (stationWorkSet == null) {
 | 
			
		||||
                    result.put(ident, otherStationWorkSet);
 | 
			
		||||
                    result.put(identStr, otherStationWorkSet);
 | 
			
		||||
                }
 | 
			
		||||
                else {
 | 
			
		||||
                    stationWorkSet.min = Math.min(stationWorkSet.min, otherStationWorkSet.min);
 | 
			
		||||
@@ -202,19 +281,9 @@ public class CalculateAverage_seijikun {
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
        result.forEach((ignored, meas) -> meas.finish());
 | 
			
		||||
 | 
			
		||||
        // print in required format
 | 
			
		||||
        System.out.write('{');
 | 
			
		||||
        var iterator = result.entrySet().iterator();
 | 
			
		||||
        while (iterator.hasNext()) {
 | 
			
		||||
            var entry = iterator.next();
 | 
			
		||||
            System.out.print(entry.getKey().name);
 | 
			
		||||
            System.out.write('=');
 | 
			
		||||
            entry.getValue().printInto(System.out);
 | 
			
		||||
            if (iterator.hasNext()) {
 | 
			
		||||
                System.out.print(", ");
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        System.out.println('}');
 | 
			
		||||
        printWorkSet(result, System.out);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user