From 0bd167557183922825a0a9ec3a1d347f4ba24b2f Mon Sep 17 00:00:00 2001 From: Dr Ian Preston <157221403+ianopolousfast@users.noreply.github.com> Date: Thu, 25 Jan 2024 22:03:05 +0000 Subject: [PATCH] Down to 14s locally (#583) Use flat array for stats. Use simd for line termination Co-authored-by: Ian Preston --- .../CalculateAverage_ianopolousfast.java | 120 ++++++++++-------- 1 file changed, 70 insertions(+), 50 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java index f1b4e7b..ab960df 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java @@ -41,7 +41,7 @@ import static java.lang.foreign.ValueLayout.*; * * Timings on 4 core i7-7500U CPU @ 2.70GHz: * average_baseline: 4m48s - * ianopolous: 15s + * ianopolous: 14s */ public class CalculateAverage_ianopolousfast { @@ -60,7 +60,7 @@ public class CalculateAverage_ianopolousfast { MemorySegment mmap = channel.map(FileChannel.MapMode.READ_ONLY, 0, filesize, arena); int nChunks = filesize < 4 * 1024 * 1024 ? 1 : Runtime.getRuntime().availableProcessors(); long chunkSize = (filesize + nChunks - 1) / nChunks; - List>> allResults = IntStream.range(0, nChunks) + List allResults = IntStream.range(0, nChunks) .parallel() .mapToObj(i -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize), mmap)) .toList(); @@ -69,7 +69,7 @@ public class CalculateAverage_ianopolousfast { .parallel() .flatMap(f -> { try { - return f.stream().filter(Objects::nonNull).flatMap(Collection::stream); + return Arrays.stream(f).filter(Objects::nonNull); } catch (Exception e) { e.printStackTrace(); @@ -104,24 +104,23 @@ public class CalculateAverage_ianopolousfast { return new Stat(stationBuffer); } - public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, List> stations) { + public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, Stat[] stations) { int index = hashToIndex(hash, MAX_STATIONS); - List matches = stations.get(index); - if (matches == null) { - List value = new ArrayList<>(); + Stat match = stations[index]; + if (match == null) { Stat res = createStation(start, end, buffer); - value.add(res); - stations.set(index, value); + stations[index] = res; return res; } else { - for (int i = 0; i < matches.size(); i++) { - Stat s = matches.get(i); - if (matchingStationBytes(start, end, buffer, s)) - return s; + while (match != null) { + if (matchingStationBytes(start, end, buffer, match)) + return match; + index = (index + 1) % stations.length; + match = stations[index]; } Stat res = createStation(start, end, buffer); - matches.add(res); + stations[index] = res; return res; } } @@ -130,50 +129,38 @@ public class CalculateAverage_ianopolousfast { return d & (-1L << ((8 - nbytes) * 8)); } - public static Stat parseStation(long lineStart, MemorySegment buffer, List> stations) { + public static Stat parseStation(long lineStart, MemorySegment buffer, Stat[] stations) { ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder()); int keySize = line.compare(VectorOperators.EQ, ';').firstTrue(); long first8 = buffer.get(LONG_LAYOUT, lineStart); - if (keySize == BYTE_SPECIES.vectorByteSize()) { - while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') { - keySize++; - } - long second8 = buffer.get(LONG_LAYOUT, lineStart + 8); - long hash = first8 ^ second8; // todo include other bytes - return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations); - } - + long second8 = 0; if (keySize <= 8) { first8 = maskHighBytes(first8, keySize & 0x07); } - long second8 = keySize <= 8 ? 0 : maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); + else if (keySize <= 16) { + second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); + } + else if (keySize == BYTE_SPECIES.vectorByteSize()) { + while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') { + keySize++; + } + second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); + } long hash = first8 ^ second8; // todo include later bytes return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations); } - public static int getDot(long d) { - // from Hacker's Delight page 92 - d = d ^ 0x2e2e2e2e2e2e2e2eL; - long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL; - y = ~(y | d | 0x7f7f7f7f7f7f7f7fL); - return Long.numberOfLeadingZeros(y) >> 3; - } - public static short getMinus(long d) { - d = d & 0xff00000000000000L; - d = d ^ 0x2d2d2d2d2d2d2d2dL; - long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL; - y = ~(y | d | 0x7f7f7f7f7f7f7f7fL); - return (short) ((Long.numberOfLeadingZeros(y) >> 6) - 1); + return ((d & 0xff00000000000000L) ^ 0x2d00000000000000L) != 0 ? 0 : (short) -1; } - public static long processTemperature(long lineSplit, MemorySegment buffer, Stat station) { + public static long processTemperature(long lineSplit, int size, MemorySegment buffer, Stat station) { long d = buffer.get(LONG_LAYOUT, lineSplit); // negative is either 0 or -1 short negative = getMinus(d); d = d << (negative * -8); - int dotIndex = getDot(d); + int dotIndex = size - 2 + negative; d = (d >> 8) | 0x30000000_00000000L; // add a leading 0 digit d = d >> 8 * (5 - dotIndex); short temperature = (short) ((byte) d - '0' + @@ -181,10 +168,41 @@ public class CalculateAverage_ianopolousfast { 100 * (((byte) (d >> 24)) - '0')); temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merkitty station.add(temperature); - return lineSplit - negative + dotIndex + 3; + return lineSplit + size + 1; } - public static List> parseStats(long startByte, long endByte, MemorySegment buffer) { + private static long parseLine(long lineStart, MemorySegment buffer, Stat[] stations) { + ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder()); + int lineSize = line.compare(VectorOperators.EQ, '\n').firstTrue(); + int index = lineSize; + while (index == BYTE_SPECIES.vectorByteSize()) { + index = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize, + ByteOrder.nativeOrder()).compare(VectorOperators.EQ, '\n').firstTrue(); + lineSize += index; + } + int keySize = lineSize - 6 + ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize - 6, + ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue(); + + long first8 = buffer.get(LONG_LAYOUT, lineStart); + long second8 = 0; + if (keySize <= 8) { + first8 = maskHighBytes(first8, keySize & 0x07); + } + else if (keySize <= 16) { + second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); + } + else if (keySize == BYTE_SPECIES.vectorByteSize()) { + while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') { + keySize++; + } + second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); + } + long hash = first8 ^ second8; // todo include later bytes + Stat station = dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations); + return processTemperature(lineStart + keySize + 1, lineSize - keySize - 1, buffer, station); + } + + public static Stat[] parseStats(long startByte, long endByte, MemorySegment buffer) { // read first partial line if (startByte > 0) { for (int i = 0; i < MAX_LINE_LENGTH; i++) { @@ -195,9 +213,7 @@ public class CalculateAverage_ianopolousfast { } } - List> stations = new ArrayList<>(MAX_STATIONS); - for (int i = 0; i < MAX_STATIONS; i++) - stations.add(null); + Stat[] stations = new Stat[MAX_STATIONS]; // Handle reading the very last few lines in the file // this allows us to not worry about reading beyond the end @@ -218,7 +234,12 @@ public class CalculateAverage_ianopolousfast { int index = 0; while (endByte + index < buffer.byteSize()) { Stat station = parseStation(index, end, stations); - index = (int) processTemperature(index + station.namelen + 1, end, station); + int tempSize = 3; + if (end.get(JAVA_BYTE, index + station.namelen + 5) == '\n') + tempSize = 4; + if (end.get(JAVA_BYTE, index + station.namelen + 6) == '\n') + tempSize = 5; + index = (int) processTemperature(index + station.namelen + 1, tempSize, end, station); } } @@ -226,10 +247,9 @@ public class CalculateAverage_ianopolousfast { return stations; } - private static void innerloop(long startByte, long endByte, MemorySegment buffer, List> stations) { + private static void innerloop(long startByte, long endByte, MemorySegment buffer, Stat[] stations) { while (startByte < endByte) { - Stat station = parseStation(startByte, buffer, stations); - startByte = processTemperature(startByte + station.namelen + 1, buffer, station); + startByte = parseLine(startByte, buffer, stations); } } @@ -278,4 +298,4 @@ public class CalculateAverage_ianopolousfast { return round((double) min) + "/" + round(((double) total) / count) + "/" + round((double) max); } } -} \ No newline at end of file +}