diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java index a8c4e4c..4bffe78 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java @@ -15,45 +15,53 @@ */ package dev.morling.onebrc; -import java.io.*; -import java.nio.*; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.ByteOrder; import java.nio.channels.*; -import java.util.concurrent.*; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; import java.util.stream.*; import java.util.*; +import static java.lang.foreign.ValueLayout.*; + /* A fast implementation with no unsafe. * Features: - * * memory mapped file + * * memory mapped file using preview Arena FFI * * read chunks in parallel * * minimise allocation * * no unsafe * * Timings on 4 core i7-7500U CPU @ 2.70GHz: * average_baseline: 4m48s - * ianopolous: 19s + * ianopolous: 16s */ public class CalculateAverage_ianopolousfast { public static final int MAX_LINE_LENGTH = 107; - public static final int MAX_STATIONS = 10_000; + public static final int MAX_STATIONS = 1 << 14; + private static final OfLong LONG_LAYOUT = JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN); public static void main(String[] args) throws Exception { - File input = new File("./measurements.txt"); - long filesize = input.length(); - // keep chunk size between 256 MB and 1G (1 chunk for files < 256MB) - long chunkSize = Math.min(Math.max((filesize + 31) / 32, 256 * 1024 * 1024), 1024 * 1024 * 1024L); - int nChunks = (int) ((filesize + chunkSize - 1) / chunkSize); - ExecutorService pool = Executors.newVirtualThreadPerTaskExecutor(); - List>>> allResults = IntStream.range(0, nChunks) - .mapToObj(i -> pool.submit(() -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize)))) + Arena arena = Arena.global(); + Path input = Path.of("measurements.txt"); + FileChannel channel = (FileChannel) Files.newByteChannel(input, StandardOpenOption.READ); + long filesize = Files.size(input); + MemorySegment mmap = channel.map(FileChannel.MapMode.READ_ONLY, 0, filesize, arena); + int nChunks = filesize < 4 * 1024 * 1024 ? 1 : Runtime.getRuntime().availableProcessors(); + long chunkSize = (filesize + nChunks - 1) / nChunks; + List>> allResults = IntStream.range(0, nChunks) + .parallel() + .mapToObj(i -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize), mmap)) .toList(); TreeMap merged = allResults.stream() .parallel() .flatMap(f -> { try { - return f.get().stream().filter(Objects::nonNull).flatMap(Collection::stream); + return f.stream().filter(Objects::nonNull).flatMap(Collection::stream); } catch (Exception e) { e.printStackTrace(); @@ -64,25 +72,39 @@ public class CalculateAverage_ianopolousfast { System.out.println(merged); } - public static boolean matchingStationBytes(int start, int end, ByteBuffer buffer, Stat existing) { - if (end - start != existing.name.length) + public static boolean matchingStationBytes(long start, long end, int offset, MemorySegment buffer, Stat existing) { + int len = (int) (end - start); + if (len != existing.name.length) return false; - for (int i = start; i < end; i++) { - if (existing.name[i - start] != buffer.get(i)) + for (int i = offset; i < len; i++) { + if (existing.name[i] != buffer.get(JAVA_BYTE, offset + start++)) return false; } return true; } - public static Stat dedupeStation(int start, int end, long hash, ByteBuffer buffer, List> stations) { - int index = Math.floorMod(hash ^ (hash >> 32), MAX_STATIONS); + private static int hashToIndex(long hash, int len) { + // From Thomas Wuerthinger's entry + int hashAsInt = (int) (hash ^ (hash >>> 28)); + int finalHash = (hashAsInt ^ (hashAsInt >>> 15)); + return (finalHash & (len - 1)); + } + + public static Stat parseStation(long start, long end, long first8, long second8, + MemorySegment buffer) { + byte[] stationBuffer = new byte[(int) (end - start)]; + for (long off = start; off < end; off++) + stationBuffer[(int) (off - start)] = buffer.get(JAVA_BYTE, off); + return new Stat(stationBuffer, first8, second8); + } + + public static Stat dedupeStation(long start, long end, long hash, long first8, long second8, + MemorySegment buffer, List> stations) { + int index = hashToIndex(hash, MAX_STATIONS); List matches = stations.get(index); if (matches == null) { List value = new ArrayList<>(); - byte[] stationBuffer = new byte[end - start]; - buffer.position(start); - buffer.get(stationBuffer); - Stat res = new Stat(stationBuffer); + Stat res = parseStation(start, end, first8, second8, buffer); value.add(res); stations.set(index, value); return res; @@ -90,136 +112,185 @@ public class CalculateAverage_ianopolousfast { else { for (int i = 0; i < matches.size(); i++) { Stat s = matches.get(i); - if (matchingStationBytes(start, end, buffer, s)) + if (first8 == s.first8 && second8 == s.second8 && matchingStationBytes(start, end, 16, buffer, s)) return s; } - byte[] stationBuffer = new byte[end - start]; - buffer.position(start); - buffer.get(stationBuffer); - Stat res = new Stat(stationBuffer); + Stat res = parseStation(start, end, first8, second8, buffer); matches.add(res); return res; } } - public static int getSemicolon(long d) { + public static Stat dedupeStation8(long start, long end, long hash, long first8, MemorySegment buffer, List> stations) { + int index = hashToIndex(hash, MAX_STATIONS); + List matches = stations.get(index); + if (matches == null) { + List value = new ArrayList<>(); + Stat station = parseStation(start, end, first8, 0, buffer); + value.add(station); + stations.set(index, value); + return station; + } + else { + for (int i = 0; i < matches.size(); i++) { + Stat s = matches.get(i); + if (first8 == s.first8 && s.name.length <= 8) + return s; + } + Stat station = parseStation(start, end, first8, 0, buffer); + matches.add(station); + return station; + } + } + + public static Stat dedupeStation16(long start, long end, long hash, long first8, long second8, MemorySegment buffer, List> stations) { + int index = hashToIndex(hash, MAX_STATIONS); + List matches = stations.get(index); + if (matches == null) { + List value = new ArrayList<>(); + Stat res = parseStation(start, end, first8, second8, buffer); + value.add(res); + stations.set(index, value); + return res; + } + else { + for (int i = 0; i < matches.size(); i++) { + Stat s = matches.get(i); + if (first8 == s.first8 && second8 == s.second8 && s.name.length <= 16) + return s; + } + Stat res = parseStation(start, end, first8, second8, buffer); + matches.add(res); + return res; + } + } + + public static long hasSemicolon(long d) { // from Hacker's Delight page 92 d = d ^ 0x3b3b3b3b3b3b3b3bL; long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL; + return ~(y | d | 0x7f7f7f7f7f7f7f7fL); + } + + public static int getSemicolonIndex(long y) { + // from Hacker's Delight page 92 + return Long.numberOfLeadingZeros(y) >> 3; + } + + static long maskHighBytes(long d, int nbytes) { + return d & (-1L << ((8 - nbytes) * 8)); + } + + public static Stat parseStation(long lineStart, MemorySegment buffer, List> stations) { + // find semicolon and update hash as we go, reading a long at a time + long d = buffer.get(LONG_LAYOUT, lineStart); + long hasSemi = hasSemicolon(d); + if (hasSemi != 0) { + int semiIndex = getSemicolonIndex(hasSemi); + d = maskHighBytes(d, semiIndex); + return dedupeStation8(lineStart, lineStart + semiIndex, d, d, buffer, stations); + } + long first8 = d; + long hash = d; + + d = buffer.get(LONG_LAYOUT, lineStart + 8); + hasSemi = hasSemicolon(d); + if (hasSemi != 0) { + int semiIndex = getSemicolonIndex(hasSemi); + if (semiIndex == 0) + return dedupeStation8(lineStart, lineStart + 8, first8, first8, buffer, stations); + d = maskHighBytes(d, semiIndex); + return dedupeStation16(lineStart, lineStart + 8 + semiIndex, first8 ^ d, first8, d, buffer, stations); + } + + int index = 8; + long second8 = d; + while (hasSemi == 0) { + hash = hash ^ d; + index += 8; + d = buffer.get(LONG_LAYOUT, lineStart + index); + hasSemi = hasSemicolon(d); + } + int semiIndex = getSemicolonIndex(hasSemi); + d = maskHighBytes(d, semiIndex); + if (semiIndex > 0) { + hash = hash ^ d; + } + return dedupeStation(lineStart, lineStart + index + semiIndex, hash, first8, second8, buffer, stations); + } + + public static int getDot(long d) { + // from Hacker's Delight page 92 + d = d ^ 0x2e2e2e2e2e2e2e2eL; + long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL; y = ~(y | d | 0x7f7f7f7f7f7f7f7fL); return Long.numberOfLeadingZeros(y) >> 3; } - public static long updateHash(long hash, long x) { - return ((hash << 5) ^ x) * 0x517cc1b727220a95L; // fxHash + public static short getMinus(long d) { + d = d & 0xff00000000000000L; + d = d ^ 0x2d2d2d2d2d2d2d2dL; + long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL; + y = ~(y | d | 0x7f7f7f7f7f7f7f7fL); + return (short) ((Long.numberOfLeadingZeros(y) >> 6) - 1); } - public static Stat parseStation(int lineStart, ByteBuffer buffer, List> stations) { - // find semicolon and update hash as we go, reading a long at a time - long d = buffer.getLong(lineStart); - - int semiIndex = getSemicolon(d); - int index = 0; - long hash = 0; - while (semiIndex == 8) { - hash = updateHash(hash, d); - index += 8; - d = buffer.getLong(lineStart + index); - semiIndex = getSemicolon(d); - } - // mask extra bytes off last long - d = d & (-1L << ((8 - semiIndex) * 8)); - if (semiIndex > 0) { - hash = updateHash(hash, d); - } - return dedupeStation(lineStart, lineStart + index + semiIndex, hash, buffer, stations); - } - - public static int processTemperature(int lineSplit, MappedByteBuffer buffer, Stat station) { - short temperature; - boolean negative = false; - byte b = buffer.get(lineSplit++); - if (b == '-') { - negative = true; - b = buffer.get(lineSplit++); - } - temperature = (short) (b - 0x30); - b = buffer.get(lineSplit++); - if (b == '.') { - b = buffer.get(lineSplit++); - temperature = (short) (temperature * 10 + (b - 0x30)); - } - else { - temperature = (short) (temperature * 10 + (b - 0x30)); - lineSplit++; - b = buffer.get(lineSplit++); - temperature = (short) (temperature * 10 + (b - 0x30)); - } - temperature = negative ? (short) -temperature : temperature; + public static long processTemperature(long lineSplit, MemorySegment buffer, Stat station) { + long d = buffer.get(LONG_LAYOUT, lineSplit); + // negative is either 0 or -1 + short negative = getMinus(d); + d = d << (negative * -8); + int dotIndex = getDot(d); + d = (d >> 8) | 0x30000000_00000000L; // add a leading 0 digit + d = d >> 8 * (5 - dotIndex); + short temperature = (short) ((byte) d - '0' + + 10 * (((byte) (d >> 16)) - '0') + + 100 * (((byte) (d >> 24)) - '0')); + temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merkitty station.add(temperature); - return lineSplit + 1; + return lineSplit - negative + dotIndex + 3; } - public static List> parseStats(long startByte, long endByte) { - try { - RandomAccessFile file = new RandomAccessFile("./measurements.txt", "r"); - long maxEnd = Math.min(file.length(), endByte + MAX_LINE_LENGTH); - long len = maxEnd - startByte; - if (len > Integer.MAX_VALUE) - throw new RuntimeException("Segment size must fit into an int"); - int maxDone = (int) (endByte - startByte); - MappedByteBuffer buffer = file.getChannel().map(FileChannel.MapMode.READ_ONLY, startByte, len); - int done = 0; - // read first partial line - if (startByte > 0) { - for (int i = 0; i < MAX_LINE_LENGTH; i++) { - byte b = buffer.get(i); - if (b == '\n') { - done = i + 1; - break; - } + public static List> parseStats(long startByte, long endByte, MemorySegment buffer) { + // read first partial line + if (startByte > 0) { + for (int i = 0; i < MAX_LINE_LENGTH; i++) { + byte b = buffer.get(JAVA_BYTE, startByte++); + if (b == '\n') { + break; } } - - List> stations = new ArrayList<>(MAX_STATIONS); - for (int i = 0; i < MAX_STATIONS; i++) - stations.add(null); - - // Handle reading the very last line in the file - // this allows us to not worry about reading a long beyond the end - // in the inner loop (reducing branches) - // We only need to read one because the min record size is 6 bytes - // so 2nd last record must be > 8 from end - if (endByte == file.length()) { - int offset = (int) (file.length() - startByte - 1); - while (buffer.get(offset) != '\n') // final new line - offset--; - offset--; - while (offset > 0 && buffer.get(offset) != '\n') // end of second last line - offset--; - maxDone = offset; - if (offset > 0) - offset++; - // copy into a 8n sized buffer to avoid reading off end - int roundedSize = (int) (file.length() - startByte) - offset; - roundedSize = (roundedSize + 7) / 8 * 8; - byte[] end = new byte[roundedSize]; - for (int i = offset; i < (int) (file.length() - startByte); i++) - end[i - offset] = buffer.get(i); - Stat station = parseStation(0, ByteBuffer.wrap(end), stations); - processTemperature(offset + station.name.length + 1, buffer, station); - } - - int lineStart = done; - while (lineStart < maxDone) { - Stat station = parseStation(lineStart, buffer, stations); - lineStart = processTemperature(lineStart + station.name.length + 1, buffer, station); - } - return stations; } - catch (IOException e) { - throw new RuntimeException(e); + + List> stations = new ArrayList<>(MAX_STATIONS); + for (int i = 0; i < MAX_STATIONS; i++) + stations.add(null); + + // Handle reading the very last line in the file + // this allows us to not worry about reading a long beyond the end + // in the inner loop (reducing branches) + // We only need to read one because the min record size is 6 bytes + // so 2nd last record must be > 8 from end + if (endByte == buffer.byteSize()) { + endByte -= 2; // skip final new line + while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n') + endByte--; + + if (endByte > 0) + endByte++; + // copy into a 8n sized buffer to avoid reading off end + MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 4); + for (long i = endByte; i < buffer.byteSize(); i++) + end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i)); + Stat station = parseStation(0, end, stations); + processTemperature(station.name.length + 1, end, station); } + + while (startByte < endByte) { + Stat station = parseStation(startByte, buffer, stations); + startByte = processTemperature(startByte + station.name.length + 1, buffer, station); + } + return stations; } public static class Stat { @@ -227,9 +298,12 @@ public class CalculateAverage_ianopolousfast { int count = 0; short min = Short.MAX_VALUE, max = Short.MIN_VALUE; long total = 0; + final long first8, second8; - public Stat(byte[] name) { + public Stat(byte[] name, long first8, long second8) { this.name = name; + this.first8 = first8; + this.second8 = second8; } public void add(short value) { @@ -263,4 +337,4 @@ public class CalculateAverage_ianopolousfast { return round((double) min) + "/" + round(((double) total) / count) + "/" + round((double) max); } } -} +} \ No newline at end of file