From f6aa09926c1f4cf30c47737f3fa37c56df89ea11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Serkan=20=C3=96ZAL?= Date: Wed, 31 Jan 2024 11:56:11 +0300 Subject: [PATCH] serkan-ozal's 6th submission: (#667) - process multiple lines at a time to get the benefit of ILP (Instruction Level Parallelism) better --- .../onebrc/CalculateAverage_serkan_ozal.java | 356 ++++++++++++------ 1 file changed, 234 insertions(+), 122 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java b/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java index 0ec4856..5325816 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java @@ -59,7 +59,7 @@ public class CalculateAverage_serkan_ozal { ? ByteVector.SPECIES_128 : ByteVector.SPECIES_64; private static final int BYTE_SPECIES_SIZE = BYTE_SPECIES.vectorByteSize(); - private static final MemorySegment ALL = MemorySegment.NULL.reinterpret(Long.MAX_VALUE); + private static final MemorySegment NULL = MemorySegment.NULL.reinterpret(Long.MAX_VALUE); private static final ByteOrder NATIVE_BYTE_ORDER = ByteOrder.nativeOrder(); private static final char NEW_LINE_SEPARATOR = '\n'; @@ -290,7 +290,7 @@ public class CalculateAverage_serkan_ozal { long regionStart = regionGiven ? (r.address() + task.start) : r.address(); long regionEnd = regionStart + task.size; - doProcessRegion(r, r.address(), regionStart, regionEnd); + doProcessRegion(regionStart, regionEnd); } if (VERBOSE) { @@ -334,86 +334,220 @@ public class CalculateAverage_serkan_ozal { } } - private void doProcessRegion(MemorySegment region, long regionAddress, long regionStart, long regionEnd) { - final int vectorSize = BYTE_SPECIES.vectorByteSize(); - final long regionMainLimit = regionEnd - BYTE_SPECIES_SIZE; + private long findClosestLineEnd(long endPos, long minPos) { + int i = 0; + int maxI = Math.min(MAX_LINE_LENGTH, (int) (endPos - minPos)); + while (i <= maxI && U.getByte(endPos - i) != NEW_LINE_SEPARATOR) { + i++; + } + return endPos - i + 1; + } - long regionPtr; + // Credits: merykitty + private long extractValue(long regionPtr, long word, OpenMap map, int entryOffset) { + // Parse and extract value + int decimalSepPos = Long.numberOfTrailingZeros(~word & 0x10101000); + int shift = 28 - decimalSepPos; + long signed = (~word << 59) >> 63; + long designMask = ~(signed & 0xFF); + long digits = ((word & designMask) << shift) & 0x0F000F0F00L; + long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; + int value = (int) ((absValue ^ signed) - signed); + + // Put extracted value into map + map.putValue(entryOffset, value); + + // Return new position + return regionPtr + (decimalSepPos >>> 3) + 3; + } + + private void doProcessRegion(long regionStart, long regionEnd) { + final int vectorSize = BYTE_SPECIES.vectorByteSize(); + + final long size = regionEnd - regionStart; + final long segmentSize = size / 2; + + final long regionStart1 = regionStart; + final long regionEnd1 = Math.max(regionStart1, findClosestLineEnd(regionStart1 + segmentSize, regionStart)); + + final long regionStart2 = regionEnd1; + final long regionEnd2 = regionEnd; + + long regionPtr1, regionPtr2; // Read and process region - main - for (regionPtr = regionStart; regionPtr < regionMainLimit;) { - regionPtr = doProcessLine(regionPtr, vectorSize); + // Inspired by: @jerrinot + // - two lines at a time (according to my experiment, this is optimum value in terms of register spilling) + // - most of the implementation is inlined + // - so get the benefit of ILP (Instruction Level Parallelism) better + for (regionPtr1 = regionStart1, regionPtr2 = regionStart2; regionPtr1 < regionEnd1 && regionPtr2 < regionEnd2;) { + // Search key/value separators and find keys' start and end positions + //////////////////////////////////////////////////////////////////////////////////////////////////////// + long keyStartPtr1 = regionPtr1; + long keyStartPtr2 = regionPtr2; + + ByteVector keyVector1 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr1, NATIVE_BYTE_ORDER); + ByteVector keyVector2 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr2, NATIVE_BYTE_ORDER); + + int keyLength1 = keyVector1.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue(); + int keyLength2 = keyVector2.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue(); + + if (keyLength1 != vectorSize && keyLength2 != vectorSize) { + regionPtr1 += (keyLength1 + 1); + regionPtr2 += (keyLength2 + 1); + } + else { + if (keyLength1 != vectorSize) { + regionPtr1 += (keyLength1 + 1); + } + else { + regionPtr1 += vectorSize; + for (; U.getByte(regionPtr1) != KEY_VALUE_SEPARATOR; regionPtr1++) + ; + keyLength1 = (int) (regionPtr1 - keyStartPtr1); + regionPtr1++; + } + if (keyLength2 != vectorSize) { + regionPtr2 += (keyLength2 + 1); + } + else { + regionPtr2 += vectorSize; + for (; U.getByte(regionPtr2) != KEY_VALUE_SEPARATOR; regionPtr2++) + ; + keyLength2 = (int) (regionPtr2 - keyStartPtr2); + regionPtr2++; + } + } + + // Read first words as they will be used while extracting values later + long word1 = U.getLong(regionPtr1); + long word2 = U.getLong(regionPtr2); + if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) { + word1 = Long.reverseBytes(word1); + word2 = Long.reverseBytes(word2); + } + //////////////////////////////////////////////////////////////////////////////////////////////////////// + + // Calculate key hashes and find entry indexes + //////////////////////////////////////////////////////////////////////////////////////////////////////// + int x1, y1, x2, y2; + if (keyLength1 >= Integer.BYTES && keyLength2 >= Integer.BYTES) { + x1 = U.getInt(keyStartPtr1); + y1 = U.getInt(keyStartPtr1 + keyLength1 - Integer.BYTES); + x2 = U.getInt(keyStartPtr2); + y2 = U.getInt(keyStartPtr2 + keyLength2 - Integer.BYTES); + } + else { + if (keyLength1 >= Integer.BYTES) { + x1 = U.getInt(keyStartPtr1); + y1 = U.getInt(keyStartPtr1 + keyLength1 - Integer.BYTES); + } + else { + x1 = U.getByte(keyStartPtr1); + y1 = U.getByte(keyStartPtr1 + keyLength1 - Byte.BYTES); + } + if (keyLength2 >= Integer.BYTES) { + x2 = U.getInt(keyStartPtr2); + y2 = U.getInt(keyStartPtr2 + keyLength2 - Integer.BYTES); + } + else { + x2 = U.getByte(keyStartPtr2); + y2 = U.getByte(keyStartPtr2 + keyLength2 - Byte.BYTES); + } + } + + int keyHash1 = (Integer.rotateLeft(x1 * OpenMap.HASH_SEED, OpenMap.HASH_ROTATE) ^ y1) * OpenMap.HASH_SEED; + int keyHash2 = (Integer.rotateLeft(x2 * OpenMap.HASH_SEED, OpenMap.HASH_ROTATE) ^ y2) * OpenMap.HASH_SEED; + + int entryIdx1 = (keyHash1 & OpenMap.ENTRY_HASH_MASK) << OpenMap.ENTRY_SIZE_SHIFT; + int entryIdx2 = (keyHash2 & OpenMap.ENTRY_HASH_MASK) << OpenMap.ENTRY_SIZE_SHIFT; + //////////////////////////////////////////////////////////////////////////////////////////////////////// + + // Put keys and calculate entry offsets to put values + //////////////////////////////////////////////////////////////////////////////////////////////////////// + int entryOffset1 = map.putKey(keyVector1, keyStartPtr1, keyLength1, entryIdx1); + int entryOffset2 = map.putKey(keyVector2, keyStartPtr2, keyLength2, entryIdx2); + //////////////////////////////////////////////////////////////////////////////////////////////////////// + + // Extract values by parsing and put them into map + //////////////////////////////////////////////////////////////////////////////////////////////////////// + regionPtr1 = extractValue(regionPtr1, word1, map, entryOffset1); + regionPtr2 = extractValue(regionPtr2, word2, map, entryOffset2); + //////////////////////////////////////////////////////////////////////////////////////////////////////// } // Read and process region - tail - for (long i = regionPtr, j = regionPtr; i < regionEnd;) { - byte b = U.getByte(i); - if (b == KEY_VALUE_SEPARATOR) { - long baseOffset = map.putKey(null, j, (int) (i - j)); - i = extractValue(i + 1, map, baseOffset); - j = i; + doProcessTail(regionPtr1, regionEnd1, regionPtr2, regionEnd2, vectorSize); + } + + private void doProcessTail(long regionPtr1, long regionEnd1, long regionPtr2, long regionEnd2, int vectorSize) { + while (regionPtr1 < regionEnd1) { + long keyStartPtr1 = regionPtr1; + ByteVector keyVector1 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr1, NATIVE_BYTE_ORDER); + int keyLength1 = keyVector1.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue(); + if (keyLength1 != vectorSize) { + regionPtr1 += (keyLength1 + 1); } else { - i++; + regionPtr1 += vectorSize; + for (; U.getByte(regionPtr1) != KEY_VALUE_SEPARATOR; regionPtr1++) + ; + keyLength1 = (int) (regionPtr1 - keyStartPtr1); + regionPtr1++; } + int entryIdx1 = map.calculateEntryIndex(keyStartPtr1, keyLength1); + int entryOffset1 = map.putKey(keyVector1, keyStartPtr1, keyLength1, entryIdx1); + long word1 = U.getLong(regionPtr1); + if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) { + word1 = Long.reverseBytes(word1); + } + regionPtr1 = extractValue(regionPtr1, word1, map, entryOffset1); + } + while (regionPtr2 < regionEnd2) { + long keyStartPtr2 = regionPtr2; + ByteVector keyVector2 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr2, NATIVE_BYTE_ORDER); + int keyLength2 = keyVector2.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue(); + if (keyLength2 != vectorSize) { + regionPtr2 += (keyLength2 + 1); + } + else { + regionPtr2 += vectorSize; + for (; U.getByte(regionPtr2) != KEY_VALUE_SEPARATOR; regionPtr2++) + ; + keyLength2 = (int) (regionPtr2 - keyStartPtr2); + regionPtr2++; + } + int entryIdx2 = map.calculateEntryIndex(keyStartPtr2, keyLength2); + int entryOffset2 = map.putKey(keyVector2, keyStartPtr2, keyLength2, entryIdx2); + long word2 = U.getLong(regionPtr2); + if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) { + word2 = Long.reverseBytes(word2); + } + regionPtr2 = extractValue(regionPtr2, word2, map, entryOffset2); } } - private long doProcessLine(long regionPtr, int vectorSize) { - // Find key/value separator - //////////////////////////////////////////////////////////////////////////////////////////////////////// - long keyStartPtr = regionPtr; - - // Vectorized search for key/value separator - ByteVector keyVector = ByteVector.fromMemorySegment(BYTE_SPECIES, ALL, regionPtr, NATIVE_BYTE_ORDER); - - int keyLength = keyVector.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue(); - // Check whether key/value separator is found in the first vector (city name is <= vector size) - if (keyLength != vectorSize) { - regionPtr += (keyLength + 1); - } - else { - regionPtr += vectorSize; - for (; U.getByte(regionPtr) != KEY_VALUE_SEPARATOR; regionPtr++) - ; - keyLength = (int) (regionPtr - keyStartPtr); - regionPtr++; - // I have tried vectorized search for key/value separator in the remaining part, - // but since majority (99%) of the city names <= 16 bytes - // and other a few longer city names (have length < 16 and <= 32) not close to 32 bytes, - // byte by byte search is better in terms of performance (according to my experiments) and simplicity. - } - //////////////////////////////////////////////////////////////////////////////////////////////////////// - - // Put key and get map offset to put value - long entryOffset = map.putKey(keyVector, keyStartPtr, keyLength); - - // Extract value, put it into map and return next position in the region to continue processing from there - return extractValue(regionPtr, map, entryOffset); - } } - // Credits: merykitty - private static long extractValue(long regionPtr, OpenMap map, long entryOffset) { - long word = U.getLong(regionPtr); - if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) { - word = Long.reverseBytes(word); + /** + * Region processor task + */ + private static final class Task { + + private final FileChannel fileChannel; + private final MemorySegment region; + private final long start; + private final long end; + private final long size; + + private Task(FileChannel fileChannel, MemorySegment region, long start, long end) { + this.fileChannel = fileChannel; + this.region = region; + this.start = start; + this.end = end; + this.size = end - start; } - // Parse and extract value - int decimalSepPos = Long.numberOfTrailingZeros(~word & 0x10101000); - int shift = 28 - decimalSepPos; - long signed = (~word << 59) >> 63; - long designMask = ~(signed & 0xFF); - long digits = ((word & designMask) << shift) & 0x0F000F0F00L; - long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; - int value = (int) ((absValue ^ signed) - signed); - - // Put extracted value into map - map.putValue(entryOffset, value); - - // Return new position - return regionPtr + (decimalSepPos >>> 3) + 3; } /** @@ -433,24 +567,6 @@ public class CalculateAverage_serkan_ozal { } - private static final class Task { - - private final FileChannel fileChannel; - private final MemorySegment region; - private final long start; - private final long end; - private final long size; - - private Task(FileChannel fileChannel, MemorySegment region, long start, long end) { - this.fileChannel = fileChannel; - this.region = region; - this.start = start; - this.end = end; - this.size = end - start; - } - - } - /** * Region processor response */ @@ -555,6 +671,9 @@ public class CalculateAverage_serkan_ozal { } + /** + * Custom map implementation to store results + */ private static final class OpenMap { // Layout @@ -585,21 +704,22 @@ public class CalculateAverage_serkan_ozal { private static final int ENTRY_MASK = MAP_SIZE - 1; private static final int KEY_ARRAY_OFFSET = KEY_OFFSET - Unsafe.ARRAY_BYTE_BASE_OFFSET; + private static final int HASH_SEED = 0x9E3779B9; + private static final int HASH_ROTATE = 5; + private final byte[] data; - private final long[] entryOffsets; + private final int[] entryOffsets; private int entryOffsetIdx; private OpenMap() { this.data = new byte[MAP_SIZE]; // Max number of unique keys are 10K, so 1 << 14 (16384) is long enough to hold offsets for all of them - this.entryOffsets = new long[1 << 14]; + this.entryOffsets = new int[1 << 14]; this.entryOffsetIdx = 0; } // Credits: merykitty - private static int calculateKeyHash(long address, int keyLength) { - int seed = 0x9E3779B9; - int rotate = 5; + private int calculateEntryIndex(long address, int keyLength) { int x, y; if (keyLength >= Integer.BYTES) { x = U.getInt(address); @@ -609,19 +729,17 @@ public class CalculateAverage_serkan_ozal { x = U.getByte(address); y = U.getByte(address + keyLength - Byte.BYTES); } - return (Integer.rotateLeft(x * seed, rotate) ^ y) * seed; + // Calculate key hash + int keyHash = (Integer.rotateLeft(x * HASH_SEED, HASH_ROTATE) ^ y) * HASH_SEED; + // Get the position of the entry in the linear map based on calculated hash + return (keyHash & ENTRY_HASH_MASK) << ENTRY_SIZE_SHIFT; } - private long putKey(ByteVector keyVector, long keyStartAddress, int keyLength) { - // Calculate hash of key - int keyHash = calculateKeyHash(keyStartAddress, keyLength); - // and get the position of the entry in the linear map based on calculated hash - int idx = (keyHash & ENTRY_HASH_MASK) << ENTRY_SIZE_SHIFT; - + private int putKey(ByteVector keyVector, long keyStartAddress, int keyLength, int entryIdx) { // Start searching from the calculated position // and continue until find an available slot in case of hash collision // TODO Prevent infinite loop if all the slots are in use for other keys - for (long entryOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET + idx;; entryOffset = (entryOffset + ENTRY_SIZE) & ENTRY_MASK) { + for (int entryOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET + entryIdx;; entryOffset = (entryOffset + ENTRY_SIZE) & ENTRY_MASK) { int keySize = U.getInt(data, entryOffset + KEY_SIZE_OFFSET); // Check whether current index is empty (no another key is inserted yet) if (keySize == 0) { @@ -633,32 +751,26 @@ public class CalculateAverage_serkan_ozal { entryOffsets[entryOffsetIdx++] = entryOffset; return entryOffset; } - int keyStartArrayOffset = (int) entryOffset + KEY_ARRAY_OFFSET; // Check for hash collision (hashes are same, but keys are different). // If there is no collision (both hashes and keys are equals), return current slot's offset. // Otherwise, continue iterating until find an available slot. - if (keySize == keyLength && keysEqual(keyVector, keyStartAddress, keyLength, keyStartArrayOffset)) { + if (keySize == keyLength && keysEqual(keyVector, keyStartAddress, keyLength, entryOffset + KEY_ARRAY_OFFSET)) { return entryOffset; } } } private boolean keysEqual(ByteVector keyVector, long keyStartAddress, int keyLength, int keyStartArrayOffset) { - int keyCheckIdx = 0; - if (keyVector != null) { - // Use vectorized search for the comparison of keys. - // Since majority of the city names >= 8 bytes and <= 16 bytes, - // this way is more efficient (according to my experiments) than any other comparisons (byte by byte or 2 longs). - ByteVector entryKeyVector = ByteVector.fromArray(BYTE_SPECIES, data, keyStartArrayOffset); - long eqMask = keyVector.compare(VectorOperators.EQ, entryKeyVector).toLong(); - int eqCount = Long.numberOfTrailingZeros(~eqMask); - if (eqCount >= keyLength) { - return true; - } - else if (keyLength <= BYTE_SPECIES_SIZE) { - return false; - } - keyCheckIdx = BYTE_SPECIES_SIZE; + // Use vectorized search for the comparison of keys. + // Since majority of the city names >= 8 bytes and <= 16 bytes, + // this way is more efficient (according to my experiments) than any other comparisons (byte by byte or 2 longs). + ByteVector entryKeyVector = ByteVector.fromArray(BYTE_SPECIES, data, keyStartArrayOffset); + int eqCount = keyVector.compare(VectorOperators.EQ, entryKeyVector).trueCount(); + if (eqCount == keyLength) { + return true; + } + else if (keyLength <= BYTE_SPECIES_SIZE) { + return false; } // Compare remaining parts of the keys @@ -671,7 +783,7 @@ public class CalculateAverage_serkan_ozal { long keyStartOffset = keyStartArrayOffset + Unsafe.ARRAY_BYTE_BASE_OFFSET; int alignedKeyLength = normalizedKeyLength & 0xFFFFFFF8; int i; - for (i = keyCheckIdx; i < alignedKeyLength; i += Long.BYTES) { + for (i = BYTE_SPECIES_SIZE; i < alignedKeyLength; i += Long.BYTES) { if (U.getLong(keyStartAddress + i) != U.getLong(data, keyStartOffset + i)) { return false; } @@ -690,18 +802,18 @@ public class CalculateAverage_serkan_ozal { return wordA == wordB; } - private void putValue(long entryOffset, int value) { - long countOffset = entryOffset + COUNT_OFFSET; + private void putValue(int entryOffset, int value) { + int countOffset = entryOffset + COUNT_OFFSET; U.putInt(data, countOffset, U.getInt(data, countOffset) + 1); - long minValueOffset = entryOffset + MIN_VALUE_OFFSET; + int minValueOffset = entryOffset + MIN_VALUE_OFFSET; if (value < U.getShort(data, minValueOffset)) { U.putShort(data, minValueOffset, (short) value); } - long maxValueOffset = entryOffset + MAX_VALUE_OFFSET; + int maxValueOffset = entryOffset + MAX_VALUE_OFFSET; if (value > U.getShort(data, maxValueOffset)) { U.putShort(data, maxValueOffset, (short) value); } - long sumOffset = entryOffset + VALUE_SUM_OFFSET; + int sumOffset = entryOffset + VALUE_SUM_OFFSET; U.putLong(data, sumOffset, U.getLong(data, sumOffset) + value); } @@ -709,13 +821,13 @@ public class CalculateAverage_serkan_ozal { // Merge this local map into global result map Arrays.sort(entryOffsets, 0, entryOffsetIdx); for (int i = 0; i < entryOffsetIdx; i++) { - long entryOffset = entryOffsets[i]; + int entryOffset = entryOffsets[i]; int keyLength = U.getInt(data, entryOffset + KEY_SIZE_OFFSET); if (keyLength == 0) { // No entry is available for this index, so continue iterating continue; } - int entryArrayIdx = (int) (entryOffset + KEY_OFFSET - Unsafe.ARRAY_BYTE_BASE_OFFSET); + int entryArrayIdx = entryOffset + KEY_OFFSET - Unsafe.ARRAY_BYTE_BASE_OFFSET; String key = new String(data, entryArrayIdx, keyLength, StandardCharsets.UTF_8); int count = U.getInt(data, entryOffset + COUNT_OFFSET); short minValue = U.getShort(data, entryOffset + MIN_VALUE_OFFSET);