From 8bae1b87810f75ddf307ba0b84400d97e3e6f851 Mon Sep 17 00:00:00 2001 From: Dr Ian Preston <157221403+ianopolousfast@users.noreply.github.com> Date: Tue, 23 Jan 2024 15:37:33 +0000 Subject: [PATCH] Use simd for name comparison (#568) Co-authored-by: Ian Preston --- .../CalculateAverage_ianopolousfast.java | 119 +++++------------- 1 file changed, 32 insertions(+), 87 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java index 8944a47..f1b4e7b 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java @@ -34,7 +34,7 @@ import static java.lang.foreign.ValueLayout.*; /* A fast implementation with no unsafe. * Features: * * memory mapped file using preview Arena FFI - * * semicolon finding using incubator vector api + * * semicolon finding and name comparison using incubator vector api * * read chunks in parallel * * minimise allocation * * no unsafe @@ -80,12 +80,11 @@ public class CalculateAverage_ianopolousfast { System.out.println(merged); } - public static boolean matchingStationBytes(long start, long end, int offset, MemorySegment buffer, Stat existing) { - int len = (int) (end - start); - if (len != existing.name.length) - return false; - for (int i = offset; i < len; i++) { - if (existing.name[i] != buffer.get(JAVA_BYTE, offset + start++)) + public static boolean matchingStationBytes(long start, long end, MemorySegment buffer, Stat existing) { + for (int index = 0; index < end - start; index += BYTE_SPECIES.vectorByteSize()) { + ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, start + index, ByteOrder.nativeOrder(), BYTE_SPECIES.indexInRange(start + index, end)); + ByteVector found = ByteVector.fromArray(BYTE_SPECIES, existing.name, index); + if (!found.eq(line).allTrue()) return false; } return true; @@ -98,21 +97,19 @@ public class CalculateAverage_ianopolousfast { return (finalHash & (len - 1)); } - public static Stat parseStation(long start, long end, long first8, long second8, - MemorySegment buffer) { + public static Stat createStation(long start, long end, MemorySegment buffer) { byte[] stationBuffer = new byte[(int) (end - start)]; for (long off = start; off < end; off++) stationBuffer[(int) (off - start)] = buffer.get(JAVA_BYTE, off); - return new Stat(stationBuffer, first8, second8); + return new Stat(stationBuffer); } - public static Stat dedupeStation(long start, long end, long hash, long first8, long second8, - MemorySegment buffer, List> stations) { + public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, List> stations) { int index = hashToIndex(hash, MAX_STATIONS); List matches = stations.get(index); if (matches == null) { List value = new ArrayList<>(); - Stat res = parseStation(start, end, first8, second8, buffer); + Stat res = createStation(start, end, buffer); value.add(res); stations.set(index, value); return res; @@ -120,54 +117,10 @@ public class CalculateAverage_ianopolousfast { else { for (int i = 0; i < matches.size(); i++) { Stat s = matches.get(i); - if (first8 == s.first8 && second8 == s.second8 && matchingStationBytes(start, end, 16, buffer, s)) + if (matchingStationBytes(start, end, buffer, s)) return s; } - Stat res = parseStation(start, end, first8, second8, buffer); - matches.add(res); - return res; - } - } - - public static Stat dedupeStation8(long start, long end, long hash, long first8, MemorySegment buffer, List> stations) { - int index = hashToIndex(hash, MAX_STATIONS); - List matches = stations.get(index); - if (matches == null) { - List value = new ArrayList<>(); - Stat station = parseStation(start, end, first8, 0, buffer); - value.add(station); - stations.set(index, value); - return station; - } - else { - for (int i = 0; i < matches.size(); i++) { - Stat s = matches.get(i); - if (first8 == s.first8 && s.name.length <= 8) - return s; - } - Stat station = parseStation(start, end, first8, 0, buffer); - matches.add(station); - return station; - } - } - - public static Stat dedupeStation16(long start, long end, long hash, long first8, long second8, MemorySegment buffer, List> stations) { - int index = hashToIndex(hash, MAX_STATIONS); - List matches = stations.get(index); - if (matches == null) { - List value = new ArrayList<>(); - Stat res = parseStation(start, end, first8, second8, buffer); - value.add(res); - stations.set(index, value); - return res; - } - else { - for (int i = 0; i < matches.size(); i++) { - Stat s = matches.get(i); - if (first8 == s.first8 && second8 == s.second8 && s.name.length <= 16) - return s; - } - Stat res = parseStation(start, end, first8, second8, buffer); + Stat res = createStation(start, end, buffer); matches.add(res); return res; } @@ -181,32 +134,22 @@ public class CalculateAverage_ianopolousfast { ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder()); int keySize = line.compare(VectorOperators.EQ, ';').firstTrue(); + long first8 = buffer.get(LONG_LAYOUT, lineStart); if (keySize == BYTE_SPECIES.vectorByteSize()) { while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') { keySize++; } - long first8 = buffer.get(LONG_LAYOUT, lineStart); - if (keySize < 8) - return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations); long second8 = buffer.get(LONG_LAYOUT, lineStart + 8); - if (keySize < 16) - return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations); long hash = first8 ^ second8; // todo include other bytes - return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations); + return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations); } - long first8 = buffer.get(LONG_LAYOUT, lineStart); if (keySize <= 8) { first8 = maskHighBytes(first8, keySize & 0x07); - return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations); - } - long second8 = buffer.get(LONG_LAYOUT, lineStart + 8); - if (keySize < 16) { - second8 = maskHighBytes(second8, keySize & 0x07); - return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations); } + long second8 = keySize <= 8 ? 0 : maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); long hash = first8 ^ second8; // todo include later bytes - return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations); + return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations); } public static int getDot(long d) { @@ -261,13 +204,10 @@ public class CalculateAverage_ianopolousfast { // in the inner loop (reducing branches) // We need at least the vector lane size bytes back if (endByte == buffer.byteSize()) { - endByte -= 1; // skip final new line // reverse at least vector lane width - while (endByte > 0 && buffer.byteSize() - endByte < BYTE_SPECIES.vectorByteSize()) { + endByte = Math.max(buffer.byteSize() - BYTE_SPECIES.vectorByteSize(), 0); + while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n') endByte--; - while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n') - endByte--; - } if (endByte > 0) endByte++; @@ -278,28 +218,33 @@ public class CalculateAverage_ianopolousfast { int index = 0; while (endByte + index < buffer.byteSize()) { Stat station = parseStation(index, end, stations); - index = (int) processTemperature(index + station.name.length + 1, end, station); + index = (int) processTemperature(index + station.namelen + 1, end, station); } } + innerloop(startByte, endByte, buffer, stations); + return stations; + } + + private static void innerloop(long startByte, long endByte, MemorySegment buffer, List> stations) { while (startByte < endByte) { Stat station = parseStation(startByte, buffer, stations); - startByte = processTemperature(startByte + station.name.length + 1, buffer, station); + startByte = processTemperature(startByte + station.namelen + 1, buffer, station); } - return stations; } public static class Stat { final byte[] name; + final int namelen; int count = 0; short min = Short.MAX_VALUE, max = Short.MIN_VALUE; long total = 0; - final long first8, second8; - public Stat(byte[] name, long first8, long second8) { - this.name = name; - this.first8 = first8; - this.second8 = second8; + public Stat(byte[] name) { + int vecSize = BYTE_SPECIES.vectorByteSize(); + int arrayLen = (name.length + vecSize - 1) / vecSize * vecSize; + this.name = Arrays.copyOfRange(name, 0, arrayLen); + this.namelen = name.length; } public void add(short value) { @@ -326,7 +271,7 @@ public class CalculateAverage_ianopolousfast { } public String name() { - return new String(name); + return new String(Arrays.copyOfRange(name, 0, namelen)); } public String toString() {