From af66ac145f126211e6fd16515468c258d91ac9c8 Mon Sep 17 00:00:00 2001 From: Thomas Wuerthinger Date: Wed, 10 Jan 2024 19:42:51 +0100 Subject: [PATCH] Second tuning for thomaswue * Optimize checking for collisions by doing this a long at a time always. * Use a long at a time scanning for delimiter. * Minor tuning. Now below 0.80s on Intel i9-13900K. * Add number parsing code from Quan Anh Mai. Fix name length issue. * Include suggestion from Alfonso Peterssen for another 1.5%. * Optimize hash collision check compare for ~4% gain. * Add perf stats based on latest version. --- .../onebrc/CalculateAverage_thomaswue.java | 199 +++++++++--------- 1 file changed, 103 insertions(+), 96 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java index 4dfeb11..7fd880a 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java @@ -30,28 +30,30 @@ import java.util.stream.IntStream; /** * Simple solution that memory maps the input file, then splits it into one segment per available core and uses - * sun.misc.Unsafe to directly access the mapped memory. - * - * Runs in 0.92s on my Intel i9-13900K + * sun.misc.Unsafe to directly access the mapped memory. Uses a long at a time when checking for collision. + *

+ * Runs in 0.70s on my Intel i9-13900K * Perf stats: - * 65,004,666,383 cpu_core/cycles/ - * 71,141,249,972 cpu_atom/cycles/ + * 40,622,862,783 cpu_core/cycles/ + * 48,241,929,925 cpu_atom/cycles/ */ public class CalculateAverage_thomaswue { private static final String FILE = "./measurements.txt"; // Holding the current result for a single city. private static class Result { - short min; - short max; + final long nameAddress; + long lastNameLong; + int remainingShift; + int min; + int max; long sum; int count; - final long nameAddress; private Result(long nameAddress, int value) { this.nameAddress = nameAddress; - this.min = (short) value; - this.max = (short) value; + this.min = value; + this.max = value; this.sum = value; this.count = 1; } @@ -66,8 +68,8 @@ public class CalculateAverage_thomaswue { // Accumulate another result into this one. private void add(Result other) { - min = (short) Math.min(min, other.min); - max = (short) Math.max(max, other.max); + min = Math.min(min, other.min); + max = Math.max(max, other.max); sum += other.sum; count += other.count; } @@ -81,8 +83,7 @@ public class CalculateAverage_thomaswue { // Parallel processing of segments. List> allResults = IntStream.range(0, chunks.length - 1).mapToObj(chunkIndex -> { HashMap cities = HashMap.newHashMap(1 << 10); - Result[] results = new Result[1 << 18]; - parseLoop(chunks[chunkIndex], chunks[chunkIndex + 1], results, cities); + parseLoop(chunks[chunkIndex], chunks[chunkIndex + 1], cities); return cities; }).parallel().toList(); @@ -114,69 +115,60 @@ public class CalculateAverage_thomaswue { } } - private static void parseLoop(long chunkStart, long chunkEnd, Result[] results, HashMap cities) { + private static void parseLoop(long chunkStart, long chunkEnd, HashMap cities) { + Result[] results = new Result[1 << 18]; long scanPtr = chunkStart; - byte b; while (scanPtr < chunkEnd) { long nameAddress = scanPtr; - int hash = 0; + long hash = 0; - // Skip first letter. - scanPtr++; - - // Scan for ';' delimiter, always 4 bytes at a time. - while (true) { - int nextVal = UNSAFE.getInt(scanPtr); - if ((nextVal & 0x3B) == 0x3B) { - scanPtr++; - break; + // Search for ';', one long at a time. + long word = UNSAFE.getLong(scanPtr); + int pos = findDelimiter(word); + if (pos != 8) { + scanPtr += pos; + word = word & (-1L >>> ((8 - pos - 1) << 3)); + hash ^= word; + } + else { + scanPtr += 8; + hash ^= word; + while (true) { + word = UNSAFE.getLong(scanPtr); + pos = findDelimiter(word); + if (pos != 8) { + scanPtr += pos; + word = word & (-1L >>> ((8 - pos - 1) << 3)); + hash ^= word; + break; + } + else { + scanPtr += 8; + hash ^= word; + } } - else if ((nextVal & 0x3B00) == 0x3B00) { - scanPtr += 2; - hash = hash ^ (nextVal & 0xFF); - break; - } - else if ((nextVal & 0x3B0000) == 0x3B0000) { - scanPtr += 3; - hash = hash ^ (nextVal & 0xFFFF); - break; - } - else if (((nextVal & 0x3B000000) == 0x3B000000)) { - scanPtr += 4; - hash = hash ^ (nextVal & 0xFFFFFF); - break; - } - scanPtr += 4; - hash = hash ^ nextVal; } // Save length of name for later. - int nameLength = (int) (scanPtr - nameAddress - 1); + int nameLength = (int) (scanPtr - nameAddress); + scanPtr++; - // Parse number. - int number; - byte sign = UNSAFE.getByte(scanPtr++); - if (sign == '-') { - number = UNSAFE.getByte(scanPtr++) - '0'; - if ((b = UNSAFE.getByte(scanPtr++)) != '.') { - number = number * 10 + (b - '0'); - scanPtr++; - } - number = number * 10 + (UNSAFE.getByte(scanPtr++) - '0'); - number = -number; - } - else { - number = sign - '0'; - if ((b = UNSAFE.getByte(scanPtr++)) != '.') { - number = number * 10 + (b - '0'); - scanPtr++; - } - number = number * 10 + (UNSAFE.getByte(scanPtr++) - '0'); - } + long numberWord = UNSAFE.getLong(scanPtr); + // The 4th binary digit of the ascii of a digit is 1 while + // that of the '.' is 0. This finds the decimal separator + // The value can be 12, 20, 28 + int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); + int number = convertIntoNumber(decimalSepPos, numberWord); + + // Skip past new line. + // scanPtr++; + scanPtr += (decimalSepPos >>> 3) + 3; // Final calculation for index into hash table. - int tableIndex = (((hash ^ (hash >>> 18)) & (results.length - 1))); - while (true) { + int hashAsInt = (int) (hash ^ (hash >>> 32)); + int finalHash = (hashAsInt ^ (hashAsInt >>> 18)); + int tableIndex = (finalHash & (results.length - 1)); + outer: while (true) { Result existingResult = results[tableIndex]; if (existingResult == null) { newEntry(results, cities, nameAddress, number, tableIndex, nameLength); @@ -184,35 +176,16 @@ public class CalculateAverage_thomaswue { } else { // Check for collision. - boolean result = true; int i = 0; - if ((long) nameLength >= 8) { - if (UNSAFE.getLong(existingResult.nameAddress) != UNSAFE.getLong(nameAddress)) { - result = false; - } - else { - i += 8; + for (; i < nameLength + 1 - 8; i += 8) { + if (UNSAFE.getLong(existingResult.nameAddress + i) != UNSAFE.getLong(nameAddress + i)) { + tableIndex = (tableIndex + 1) & (results.length - 1); + continue outer; } } - else if ((long) nameLength >= 4) { - if (UNSAFE.getInt(existingResult.nameAddress) != UNSAFE.getInt(nameAddress)) { - result = false; - } - else { - i += 4; - } - } - if (result) { - for (; i < (long) nameLength; ++i) { - if (UNSAFE.getByte(existingResult.nameAddress + i) != UNSAFE.getByte(nameAddress + i)) { - result = false; - break; - } - } - } - if (result) { - existingResult.min = (short) Math.min(existingResult.min, number); - existingResult.max = (short) Math.max(existingResult.max, number); + if (((existingResult.lastNameLong ^ UNSAFE.getLong(nameAddress + i)) << existingResult.remainingShift) == 0) { + existingResult.min = Math.min(existingResult.min, number); + existingResult.max = Math.max(existingResult.max, number); existingResult.sum += number; existingResult.count++; break; @@ -223,18 +196,52 @@ public class CalculateAverage_thomaswue { } } } - - // Skip new line. - scanPtr++; } } + // Special method to convert a number in the specific format into an int value without branches created by + // Quan Anh Mai. + private static int convertIntoNumber(int decimalSepPos, long numberWord) { + int shift = 28 - decimalSepPos; + // signed is -1 if negative, 0 otherwise + long signed = (~numberWord << 59) >> 63; + long designMask = ~(signed & 0xFF); + // Align the number to a specific position and transform the ascii code + // to actual digit value in each byte + long digits = ((numberWord & designMask) << shift) & 0x0F000F0F00L; + + // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit) + // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) = + // 0x000000UU00TTHH00 + + // 0x00UU00TTHH000000 * 10 + + // 0xUU00TTHH00000000 * 100 + // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400 + // This results in our value lies in the bit 32 to 41 of this product + // That was close :) + long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; + long value = (absValue ^ signed) - signed; + return (int) value; + } + + private static int findDelimiter(long word) { + long input = word ^ 0x3B3B3B3B3B3B3B3BL; + long tmp = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L; + return Long.numberOfTrailingZeros(tmp) >>> 3; + } + private static void newEntry(Result[] results, HashMap cities, long nameAddress, int number, int hash, int nameLength) { Result r = new Result(nameAddress, number); results[hash] = r; byte[] bytes = new byte[nameLength]; + + int i = 0; + for (; i < nameLength + 1 - 8; i += 8) { + } + r.lastNameLong = UNSAFE.getLong(nameAddress + i); + r.remainingShift = (64 - (nameLength + 1 - i) << 3); UNSAFE.copyMemory(null, nameAddress, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength); - cities.put(new String(bytes, StandardCharsets.UTF_8), r); + String nameAsString = new String(bytes, StandardCharsets.UTF_8); + cities.put(nameAsString, r); } private static long[] getSegments(int numberOfChunks) throws IOException {