From 062f2bbecf586d85ff44dec42cc63f94e49bc6b8 Mon Sep 17 00:00:00 2001 From: Dr Ian Preston <157221403+ianopolousfast@users.noreply.github.com> Date: Sat, 20 Jan 2024 19:09:40 +0000 Subject: [PATCH] Introducing the vector api. 1s faster on 4 core i7 (#506) Co-authored-by: Ian Preston --- calculate_average_ianopolousfast.sh | 2 +- .../CalculateAverage_ianopolousfast.java | 102 +++++++++--------- 2 files changed, 50 insertions(+), 54 deletions(-) diff --git a/calculate_average_ianopolousfast.sh b/calculate_average_ianopolousfast.sh index e5c0977..06c31d9 100755 --- a/calculate_average_ianopolousfast.sh +++ b/calculate_average_ianopolousfast.sh @@ -15,5 +15,5 @@ # limitations under the License. # -JAVA_OPTS="--enable-preview" +JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ianopolousfast diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java index 4bffe78..8944a47 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java @@ -15,6 +15,10 @@ */ package dev.morling.onebrc; +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorSpecies; + import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.nio.ByteOrder; @@ -30,19 +34,23 @@ import static java.lang.foreign.ValueLayout.*; /* A fast implementation with no unsafe. * Features: * * memory mapped file using preview Arena FFI + * * semicolon finding using incubator vector api * * read chunks in parallel * * minimise allocation * * no unsafe * * Timings on 4 core i7-7500U CPU @ 2.70GHz: * average_baseline: 4m48s - * ianopolous: 16s + * ianopolous: 15s */ public class CalculateAverage_ianopolousfast { public static final int MAX_LINE_LENGTH = 107; public static final int MAX_STATIONS = 1 << 14; private static final OfLong LONG_LAYOUT = JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN); + private static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_PREFERRED.length() >= 32 + ? ByteVector.SPECIES_256 + : ByteVector.SPECIES_128; public static void main(String[] args) throws Exception { Arena arena = Arena.global(); @@ -165,58 +173,40 @@ public class CalculateAverage_ianopolousfast { } } - public static long hasSemicolon(long d) { - // from Hacker's Delight page 92 - d = d ^ 0x3b3b3b3b3b3b3b3bL; - long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL; - return ~(y | d | 0x7f7f7f7f7f7f7f7fL); - } - - public static int getSemicolonIndex(long y) { - // from Hacker's Delight page 92 - return Long.numberOfLeadingZeros(y) >> 3; - } - static long maskHighBytes(long d, int nbytes) { return d & (-1L << ((8 - nbytes) * 8)); } public static Stat parseStation(long lineStart, MemorySegment buffer, List> stations) { - // find semicolon and update hash as we go, reading a long at a time - long d = buffer.get(LONG_LAYOUT, lineStart); - long hasSemi = hasSemicolon(d); - if (hasSemi != 0) { - int semiIndex = getSemicolonIndex(hasSemi); - d = maskHighBytes(d, semiIndex); - return dedupeStation8(lineStart, lineStart + semiIndex, d, d, buffer, stations); - } - long first8 = d; - long hash = d; + ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder()); + int keySize = line.compare(VectorOperators.EQ, ';').firstTrue(); - d = buffer.get(LONG_LAYOUT, lineStart + 8); - hasSemi = hasSemicolon(d); - if (hasSemi != 0) { - int semiIndex = getSemicolonIndex(hasSemi); - if (semiIndex == 0) - return dedupeStation8(lineStart, lineStart + 8, first8, first8, buffer, stations); - d = maskHighBytes(d, semiIndex); - return dedupeStation16(lineStart, lineStart + 8 + semiIndex, first8 ^ d, first8, d, buffer, stations); + if (keySize == BYTE_SPECIES.vectorByteSize()) { + while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') { + keySize++; + } + long first8 = buffer.get(LONG_LAYOUT, lineStart); + if (keySize < 8) + return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations); + long second8 = buffer.get(LONG_LAYOUT, lineStart + 8); + if (keySize < 16) + return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations); + long hash = first8 ^ second8; // todo include other bytes + return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations); } - int index = 8; - long second8 = d; - while (hasSemi == 0) { - hash = hash ^ d; - index += 8; - d = buffer.get(LONG_LAYOUT, lineStart + index); - hasSemi = hasSemicolon(d); + long first8 = buffer.get(LONG_LAYOUT, lineStart); + if (keySize <= 8) { + first8 = maskHighBytes(first8, keySize & 0x07); + return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations); } - int semiIndex = getSemicolonIndex(hasSemi); - d = maskHighBytes(d, semiIndex); - if (semiIndex > 0) { - hash = hash ^ d; + long second8 = buffer.get(LONG_LAYOUT, lineStart + 8); + if (keySize < 16) { + second8 = maskHighBytes(second8, keySize & 0x07); + return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations); } - return dedupeStation(lineStart, lineStart + index + semiIndex, hash, first8, second8, buffer, stations); + long hash = first8 ^ second8; // todo include later bytes + return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations); } public static int getDot(long d) { @@ -266,24 +256,30 @@ public class CalculateAverage_ianopolousfast { for (int i = 0; i < MAX_STATIONS; i++) stations.add(null); - // Handle reading the very last line in the file - // this allows us to not worry about reading a long beyond the end + // Handle reading the very last few lines in the file + // this allows us to not worry about reading beyond the end // in the inner loop (reducing branches) - // We only need to read one because the min record size is 6 bytes - // so 2nd last record must be > 8 from end + // We need at least the vector lane size bytes back if (endByte == buffer.byteSize()) { - endByte -= 2; // skip final new line - while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n') + endByte -= 1; // skip final new line + // reverse at least vector lane width + while (endByte > 0 && buffer.byteSize() - endByte < BYTE_SPECIES.vectorByteSize()) { endByte--; + while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n') + endByte--; + } if (endByte > 0) endByte++; - // copy into a 8n sized buffer to avoid reading off end - MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 4); + // copy into a larger buffer to avoid reading off end + MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + BYTE_SPECIES.vectorByteSize()); for (long i = endByte; i < buffer.byteSize(); i++) end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i)); - Stat station = parseStation(0, end, stations); - processTemperature(station.name.length + 1, end, station); + int index = 0; + while (endByte + index < buffer.byteSize()) { + Stat station = parseStation(index, end, stations); + index = (int) processTemperature(index + station.name.length + 1, end, station); + } } while (startByte < endByte) {