From c886aaba3498fbc381a009d0f0a466a20194992b Mon Sep 17 00:00:00 2001 From: Parker Timmins <45302127+parkertimmins@users.noreply.github.com> Date: Tue, 23 Jan 2024 12:43:34 -0600 Subject: [PATCH] Deploy v2 for parkertimmins (#524) * Deploy v2 for parkertimmins Main changes: - fix hash which masked incorrectly - do station equality check in simd - make station array length multiple of 32 - search for newline rather than semicolon * Fix bug - entries were being skipped between batches At the boundary between two batches, the first batch would stop after crossing a limit with a padding of 200 characters applied. The next batch should then start looking for the first full entry after the padding. This padding logic had been removed when starting a batch. For this reason, entries starting in the 200 character padding between batches were skipped. --- .../CalculateAverage_parkertimmins.java | 175 ++++++++---------- 1 file changed, 78 insertions(+), 97 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_parkertimmins.java b/src/main/java/dev/morling/onebrc/CalculateAverage_parkertimmins.java index 71412fb..c689ff1 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_parkertimmins.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_parkertimmins.java @@ -16,28 +16,21 @@ package dev.morling.onebrc; import jdk.incubator.vector.ByteVector; -import jdk.incubator.vector.VectorMask; -import jdk.incubator.vector.VectorOperators; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; -import java.lang.reflect.Array; -import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; import java.io.IOException; import java.io.RandomAccessFile; -import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.util.*; import java.util.concurrent.atomic.AtomicLong; -import java.util.zip.CRC32C; public class CalculateAverage_parkertimmins { private static final String FILE = "./measurements.txt"; - // private static final String FILE = "./full_measurements.no_license"; private static record ResultRow(double min, double mean, double max) { public String toString() { @@ -51,14 +44,16 @@ public class CalculateAverage_parkertimmins { static class OpenHashTable { static class Entry { + + // key always stored as multiple of 32 bytes byte[] key; - short min; - short max; + byte keyLen; + short min = Short.MAX_VALUE; + short max = Short.MIN_VALUE; long sum = 0; long count = 0; - int hash; - void merge(OpenHashTable.Entry other) { + void merge(Entry other) { min = (short) Math.min(min, other.min); max = (short) Math.max(max, other.max); sum += other.sum; @@ -80,15 +75,20 @@ public class CalculateAverage_parkertimmins { // key not present, so add it if (entry == null) { entry = entries[idx] = new Entry(); - entry.key = Arrays.copyOf(buf, sLen); + + int rem = sLen % 32; + int arrayLen = rem == 0 ? sLen : sLen + 32 - rem; + entry.key = Arrays.copyOf(buf, arrayLen); + Arrays.fill(entry.key, sLen, arrayLen, (byte) 0); + entry.keyLen = (byte) sLen; + entry.min = entry.max = val; entry.sum += val; entry.count++; - entry.hash = hash; break; } else { - if (entry.hash == hash && entry.key.length == sLen && Arrays.equals(entry.key, 0, sLen, buf, 0, sLen)) { + if (entry.keyLen == sLen && eq(buf, entry.key, entry.keyLen)) { entry.min = (short) Math.min(entry.min, val); entry.max = (short) Math.max(entry.max, val); entry.sum += val; @@ -103,6 +103,23 @@ public class CalculateAverage_parkertimmins { } } + static boolean eq(byte[] buf, byte[] entryKey, int sLen) { + int needed = sLen; + for (int offset = 0; offset <= 96; offset += 32) { + var a = ByteVector.fromArray(ByteVector.SPECIES_256, buf, offset); + var b = ByteVector.fromArray(ByteVector.SPECIES_256, entryKey, offset); + int matches = a.eq(b).not().firstTrue(); + if (needed <= 32) { + return matches >= needed; + } + else if (matches < 32) { + return false; + } + needed -= 32; + } + return false; + } + static long findNextEntryStart(MemorySegment ms, long offset) { long curr = offset; while (ms.get(ValueLayout.JAVA_BYTE, curr) != '\n') { @@ -112,8 +129,17 @@ public class CalculateAverage_parkertimmins { return curr; } - static short[] digits10s = { 0, 100, 200, 300, 400, 500, 600, 700, 800, 900 }; - static short[] digits1s = { 0, 10, 20, 30, 40, 50, 60, 70, 80, 90 }; + static short[] digits2s = new short[256]; + static short[] digits1s = new short[256]; + static short[] digits0s = new short[256]; + + static { + for (int i = 0; i < 10; ++i) { + digits2s[i + ((int) '0')] = (short) (i * 100); + digits1s[i + ((int) '0')] = (short) (i * 10); + digits0s[i + ((int) '0')] = (short) i; + } + } static void processRangeScalar(MemorySegment ms, long start, long end, final OpenHashTable localAgg) { byte[] buf = new byte[128]; @@ -139,9 +165,10 @@ public class CalculateAverage_parkertimmins { boolean neg = ms.get(ValueLayout.JAVA_BYTE, tempIdx) == '-'; boolean twoDig = ms.get(ValueLayout.JAVA_BYTE, tempIdx + 1 + (neg ? 1 : 0)) == '.'; int len = 3 + (neg ? 1 : 0) + (twoDig ? 0 : 1); - int d0 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 1)) - '0'; - int d1 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 3)) - '0'; - int base = d0 + digits1s[d1] + (twoDig ? 0 : digits10s[((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 4)) - '0']); + int d0 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 1)); + int d1 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 3)); + int d2 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 4)); // could be - or \n + int base = digits0s[d0] + digits1s[d1] + digits2s[d2]; short temp = (short) (neg ? -base : base); localAgg.add(buf, sLen, temp, hash); @@ -150,100 +177,55 @@ public class CalculateAverage_parkertimmins { } static int hash(byte[] buf, int sLen) { - // TODO find a hash that works directly from byte array - // if shorter than 8 chars, mask out upper bits - long mask = sLen < 8 ? -(1L << ((8 - sLen) << 3)) : 0xFFFFFFFFL; - long val = ((buf[0] & 0xffL) << 56) | ((buf[1] & 0xffL) << 48) | ((buf[2] & 0xffL) << 40) | ((buf[3] & 0xffL) << 32) | ((buf[4] & 0xffL) << 24) - | ((buf[5] & 0xffL) << 16) | ((buf[6] & 0xFFL) << 8) | (buf[7] & 0xffL); + int shift = Math.max(0, 8 - sLen) << 3; + long mask = (~0L) >>> shift; + long val = ((buf[7] & 0xffL) << 56) | ((buf[6] & 0xffL) << 48) | ((buf[5] & 0xffL) << 40) | ((buf[4] & 0xffL) << 32) | ((buf[3] & 0xffL) << 24) + | ((buf[2] & 0xffL) << 16) | ((buf[1] & 0xFFL) << 8) | (buf[0] & 0xffL); val &= mask; - - // also worth trying: https://lemire.me/blog/2015/10/22/faster-hashing-without-effort/ // lemire: https://lemire.me/blog/2023/07/14/recognizing-string-prefixes-with-simd-instructions/ int hash = (int) (((((val >> 32) ^ val) & 0xffffffffL) * 3523216699L) >> 32); return hash; } - static void processRangeSIMD(MemorySegment ms, boolean frontPad, boolean backPad, long start, long end, final OpenHashTable localAgg) { + static void processRangeSIMD(MemorySegment ms, boolean isFirst, boolean isLast, long start, long end, final OpenHashTable localAgg) { byte[] buf = new byte[128]; - long curr = frontPad ? findNextEntryStart(ms, start) : start; - long limit = end - padding; + long curr = isFirst ? start : findNextEntryStart(ms, start); + long limit = isLast ? end - padding : end; - var needle = ByteVector.broadcast(ByteVector.SPECIES_256, ';'); while (curr < limit) { - - int segStart = 0; - int sLen; - - while (true) { - var section = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, ms, curr + segStart, ByteOrder.LITTLE_ENDIAN); - section.intoArray(buf, segStart); - VectorMask matches = section.compare(VectorOperators.EQ, needle); - int idx = matches.firstTrue(); + int nl = 0; + for (int offset = 0; offset < 128; offset += 32) { + ByteVector section = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, ms, curr + offset, ByteOrder.LITTLE_ENDIAN); + section.intoArray(buf, offset); + var idx = section.eq((byte) '\n').firstTrue(); if (idx != 32) { - sLen = segStart + idx; + nl = offset + idx; break; } - segStart += 32; } - int hash = hash(buf, sLen); - - curr += sLen; - curr++; // semicolon - - long tempIdx = curr; - boolean neg = ms.get(ValueLayout.JAVA_BYTE, tempIdx) == '-'; - boolean twoDig = ms.get(ValueLayout.JAVA_BYTE, tempIdx + 1 + (neg ? 1 : 0)) == '.'; - int len = 3 + (neg ? 1 : 0) + (twoDig ? 0 : 1); - int d0 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 1)) - '0'; - int d1 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 3)) - '0'; - int base = d0 + digits1s[d1] + (twoDig ? 0 : digits10s[((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 4)) - '0']); + int nl1 = buf[nl - 1]; + int nl3 = buf[nl - 3]; + int nl4 = buf[nl - 4]; + int nl5 = buf[nl - 5]; + int base = (nl1 - '0') + 10 * (nl3 - '0') + digits2s[nl4]; + boolean neg = nl4 == '-' || (nl4 != ';' && nl5 == '-'); short temp = (short) (neg ? -base : base); + int tempLen = 4 + (neg ? 1 : 0) + (base >= 100 ? 1 : 0); + int semi = nl - tempLen; - localAgg.add(buf, sLen, temp, hash); - curr = tempIdx + len + 1; + int hash = hash(buf, semi); + localAgg.add(buf, semi, temp, hash); + curr += (nl + 1); } // last batch is near end of file, process without SIMD to avoid out-of-bounds - if (!backPad) { + if (isLast) { processRangeScalar(ms, curr, end, localAgg); } } - /** - * For debugging issues with hash function - */ - static void checkHashDistributionQuality(ArrayList localAggs) { - HashSet uniquesHashValues = new HashSet(); - HashSet uniqueCities = new HashSet(); - HashMap> cityToHash = new HashMap<>(); - - for (var agg : localAggs) { - for (OpenHashTable.Entry entry : agg.entries) { - if (entry == null) { - continue; - } - uniquesHashValues.add(entry.hash); - String station = new String(entry.key, StandardCharsets.UTF_8); // for UTF-8 encoding - uniqueCities.add(station); - - if (!cityToHash.containsKey(station)) { - cityToHash.put(station, new HashSet<>()); - } - cityToHash.get(station).add(entry.hash); - } - } - - for (var pair : cityToHash.entrySet()) { - if (pair.getValue().size() > 1) { - System.err.println("multiple hashes: " + pair.getKey() + " " + pair.getValue()); - } - } - - System.err.println("Unique stations: " + uniqueCities.size() + ", unique hash values: " + uniquesHashValues.size()); - } - /** * Combine thread local values */ @@ -254,7 +236,7 @@ public class CalculateAverage_parkertimmins { if (entry == null) { continue; } - String station = new String(entry.key, StandardCharsets.UTF_8); // for UTF-8 encoding + String station = new String(entry.key, 0, entry.keyLen, StandardCharsets.UTF_8); // for UTF-8 encoding var currentVal = global.get(station); if (currentVal != null) { currentVal.merge(entry); @@ -267,8 +249,6 @@ public class CalculateAverage_parkertimmins { return global; } - static final long batchSize = 10_000_000; - static final int padding = 200; // max entry size is 107ish == 100 (station) + 1 (semicolon) + 5 (temp, eg -99.9) + 1 (newline) public static void main(String[] args) throws IOException, InterruptedException { @@ -277,7 +257,10 @@ public class CalculateAverage_parkertimmins { int numThreads = Runtime.getRuntime().availableProcessors(); + final long batchSize = 10_000_000; + final long fileSize = channel.size(); + // final long batchSize = fileSize / numThreads + 1; final MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()); final ArrayList localAggs = new ArrayList<>(numThreads); Thread[] threads = new Thread[numThreads]; @@ -299,11 +282,9 @@ public class CalculateAverage_parkertimmins { break; } final long endBatch = Math.min(startBatch + batchSize, fileSize); - final boolean first = startBatch == 0; - final boolean frontPad = !first; - final boolean last = endBatch == fileSize; - final boolean backPad = !last; - processRangeSIMD(ms, frontPad, backPad, startBatch, endBatch, localAgg); + final boolean isFirstBatch = startBatch == 0; + final boolean isLastBatch = endBatch == fileSize; + processRangeSIMD(ms, isFirstBatch, isLastBatch, startBatch, endBatch, localAgg); } } }