serkan-ozal's 6th submission: (#667)

- process multiple lines at a time to get the benefit of ILP (Instruction Level Parallelism) better
This commit is contained in:
Serkan ÖZAL 2024-01-31 11:56:11 +03:00 committed by GitHub
parent 1a4ac0d249
commit f6aa09926c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -59,7 +59,7 @@ public class CalculateAverage_serkan_ozal {
? ByteVector.SPECIES_128
: ByteVector.SPECIES_64;
private static final int BYTE_SPECIES_SIZE = BYTE_SPECIES.vectorByteSize();
private static final MemorySegment ALL = MemorySegment.NULL.reinterpret(Long.MAX_VALUE);
private static final MemorySegment NULL = MemorySegment.NULL.reinterpret(Long.MAX_VALUE);
private static final ByteOrder NATIVE_BYTE_ORDER = ByteOrder.nativeOrder();
private static final char NEW_LINE_SEPARATOR = '\n';
@ -290,7 +290,7 @@ public class CalculateAverage_serkan_ozal {
long regionStart = regionGiven ? (r.address() + task.start) : r.address();
long regionEnd = regionStart + task.size;
doProcessRegion(r, r.address(), regionStart, regionEnd);
doProcessRegion(regionStart, regionEnd);
}
if (VERBOSE) {
@ -334,72 +334,17 @@ public class CalculateAverage_serkan_ozal {
}
}
private void doProcessRegion(MemorySegment region, long regionAddress, long regionStart, long regionEnd) {
final int vectorSize = BYTE_SPECIES.vectorByteSize();
final long regionMainLimit = regionEnd - BYTE_SPECIES_SIZE;
long regionPtr;
// Read and process region - main
for (regionPtr = regionStart; regionPtr < regionMainLimit;) {
regionPtr = doProcessLine(regionPtr, vectorSize);
}
// Read and process region - tail
for (long i = regionPtr, j = regionPtr; i < regionEnd;) {
byte b = U.getByte(i);
if (b == KEY_VALUE_SEPARATOR) {
long baseOffset = map.putKey(null, j, (int) (i - j));
i = extractValue(i + 1, map, baseOffset);
j = i;
}
else {
private long findClosestLineEnd(long endPos, long minPos) {
int i = 0;
int maxI = Math.min(MAX_LINE_LENGTH, (int) (endPos - minPos));
while (i <= maxI && U.getByte(endPos - i) != NEW_LINE_SEPARATOR) {
i++;
}
}
}
private long doProcessLine(long regionPtr, int vectorSize) {
// Find key/value separator
////////////////////////////////////////////////////////////////////////////////////////////////////////
long keyStartPtr = regionPtr;
// Vectorized search for key/value separator
ByteVector keyVector = ByteVector.fromMemorySegment(BYTE_SPECIES, ALL, regionPtr, NATIVE_BYTE_ORDER);
int keyLength = keyVector.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue();
// Check whether key/value separator is found in the first vector (city name is <= vector size)
if (keyLength != vectorSize) {
regionPtr += (keyLength + 1);
}
else {
regionPtr += vectorSize;
for (; U.getByte(regionPtr) != KEY_VALUE_SEPARATOR; regionPtr++)
;
keyLength = (int) (regionPtr - keyStartPtr);
regionPtr++;
// I have tried vectorized search for key/value separator in the remaining part,
// but since majority (99%) of the city names <= 16 bytes
// and other a few longer city names (have length < 16 and <= 32) not close to 32 bytes,
// byte by byte search is better in terms of performance (according to my experiments) and simplicity.
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
// Put key and get map offset to put value
long entryOffset = map.putKey(keyVector, keyStartPtr, keyLength);
// Extract value, put it into map and return next position in the region to continue processing from there
return extractValue(regionPtr, map, entryOffset);
}
return endPos - i + 1;
}
// Credits: merykitty
private static long extractValue(long regionPtr, OpenMap map, long entryOffset) {
long word = U.getLong(regionPtr);
if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
word = Long.reverseBytes(word);
}
private long extractValue(long regionPtr, long word, OpenMap map, int entryOffset) {
// Parse and extract value
int decimalSepPos = Long.numberOfTrailingZeros(~word & 0x10101000);
int shift = 28 - decimalSepPos;
@ -416,23 +361,177 @@ public class CalculateAverage_serkan_ozal {
return regionPtr + (decimalSepPos >>> 3) + 3;
}
private void doProcessRegion(long regionStart, long regionEnd) {
final int vectorSize = BYTE_SPECIES.vectorByteSize();
final long size = regionEnd - regionStart;
final long segmentSize = size / 2;
final long regionStart1 = regionStart;
final long regionEnd1 = Math.max(regionStart1, findClosestLineEnd(regionStart1 + segmentSize, regionStart));
final long regionStart2 = regionEnd1;
final long regionEnd2 = regionEnd;
long regionPtr1, regionPtr2;
// Read and process region - main
// Inspired by: @jerrinot
// - two lines at a time (according to my experiment, this is optimum value in terms of register spilling)
// - most of the implementation is inlined
// - so get the benefit of ILP (Instruction Level Parallelism) better
for (regionPtr1 = regionStart1, regionPtr2 = regionStart2; regionPtr1 < regionEnd1 && regionPtr2 < regionEnd2;) {
// Search key/value separators and find keys' start and end positions
////////////////////////////////////////////////////////////////////////////////////////////////////////
long keyStartPtr1 = regionPtr1;
long keyStartPtr2 = regionPtr2;
ByteVector keyVector1 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr1, NATIVE_BYTE_ORDER);
ByteVector keyVector2 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr2, NATIVE_BYTE_ORDER);
int keyLength1 = keyVector1.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue();
int keyLength2 = keyVector2.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue();
if (keyLength1 != vectorSize && keyLength2 != vectorSize) {
regionPtr1 += (keyLength1 + 1);
regionPtr2 += (keyLength2 + 1);
}
else {
if (keyLength1 != vectorSize) {
regionPtr1 += (keyLength1 + 1);
}
else {
regionPtr1 += vectorSize;
for (; U.getByte(regionPtr1) != KEY_VALUE_SEPARATOR; regionPtr1++)
;
keyLength1 = (int) (regionPtr1 - keyStartPtr1);
regionPtr1++;
}
if (keyLength2 != vectorSize) {
regionPtr2 += (keyLength2 + 1);
}
else {
regionPtr2 += vectorSize;
for (; U.getByte(regionPtr2) != KEY_VALUE_SEPARATOR; regionPtr2++)
;
keyLength2 = (int) (regionPtr2 - keyStartPtr2);
regionPtr2++;
}
}
// Read first words as they will be used while extracting values later
long word1 = U.getLong(regionPtr1);
long word2 = U.getLong(regionPtr2);
if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
word1 = Long.reverseBytes(word1);
word2 = Long.reverseBytes(word2);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
// Calculate key hashes and find entry indexes
////////////////////////////////////////////////////////////////////////////////////////////////////////
int x1, y1, x2, y2;
if (keyLength1 >= Integer.BYTES && keyLength2 >= Integer.BYTES) {
x1 = U.getInt(keyStartPtr1);
y1 = U.getInt(keyStartPtr1 + keyLength1 - Integer.BYTES);
x2 = U.getInt(keyStartPtr2);
y2 = U.getInt(keyStartPtr2 + keyLength2 - Integer.BYTES);
}
else {
if (keyLength1 >= Integer.BYTES) {
x1 = U.getInt(keyStartPtr1);
y1 = U.getInt(keyStartPtr1 + keyLength1 - Integer.BYTES);
}
else {
x1 = U.getByte(keyStartPtr1);
y1 = U.getByte(keyStartPtr1 + keyLength1 - Byte.BYTES);
}
if (keyLength2 >= Integer.BYTES) {
x2 = U.getInt(keyStartPtr2);
y2 = U.getInt(keyStartPtr2 + keyLength2 - Integer.BYTES);
}
else {
x2 = U.getByte(keyStartPtr2);
y2 = U.getByte(keyStartPtr2 + keyLength2 - Byte.BYTES);
}
}
int keyHash1 = (Integer.rotateLeft(x1 * OpenMap.HASH_SEED, OpenMap.HASH_ROTATE) ^ y1) * OpenMap.HASH_SEED;
int keyHash2 = (Integer.rotateLeft(x2 * OpenMap.HASH_SEED, OpenMap.HASH_ROTATE) ^ y2) * OpenMap.HASH_SEED;
int entryIdx1 = (keyHash1 & OpenMap.ENTRY_HASH_MASK) << OpenMap.ENTRY_SIZE_SHIFT;
int entryIdx2 = (keyHash2 & OpenMap.ENTRY_HASH_MASK) << OpenMap.ENTRY_SIZE_SHIFT;
////////////////////////////////////////////////////////////////////////////////////////////////////////
// Put keys and calculate entry offsets to put values
////////////////////////////////////////////////////////////////////////////////////////////////////////
int entryOffset1 = map.putKey(keyVector1, keyStartPtr1, keyLength1, entryIdx1);
int entryOffset2 = map.putKey(keyVector2, keyStartPtr2, keyLength2, entryIdx2);
////////////////////////////////////////////////////////////////////////////////////////////////////////
// Extract values by parsing and put them into map
////////////////////////////////////////////////////////////////////////////////////////////////////////
regionPtr1 = extractValue(regionPtr1, word1, map, entryOffset1);
regionPtr2 = extractValue(regionPtr2, word2, map, entryOffset2);
////////////////////////////////////////////////////////////////////////////////////////////////////////
}
// Read and process region - tail
doProcessTail(regionPtr1, regionEnd1, regionPtr2, regionEnd2, vectorSize);
}
private void doProcessTail(long regionPtr1, long regionEnd1, long regionPtr2, long regionEnd2, int vectorSize) {
while (regionPtr1 < regionEnd1) {
long keyStartPtr1 = regionPtr1;
ByteVector keyVector1 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr1, NATIVE_BYTE_ORDER);
int keyLength1 = keyVector1.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue();
if (keyLength1 != vectorSize) {
regionPtr1 += (keyLength1 + 1);
}
else {
regionPtr1 += vectorSize;
for (; U.getByte(regionPtr1) != KEY_VALUE_SEPARATOR; regionPtr1++)
;
keyLength1 = (int) (regionPtr1 - keyStartPtr1);
regionPtr1++;
}
int entryIdx1 = map.calculateEntryIndex(keyStartPtr1, keyLength1);
int entryOffset1 = map.putKey(keyVector1, keyStartPtr1, keyLength1, entryIdx1);
long word1 = U.getLong(regionPtr1);
if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
word1 = Long.reverseBytes(word1);
}
regionPtr1 = extractValue(regionPtr1, word1, map, entryOffset1);
}
while (regionPtr2 < regionEnd2) {
long keyStartPtr2 = regionPtr2;
ByteVector keyVector2 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr2, NATIVE_BYTE_ORDER);
int keyLength2 = keyVector2.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue();
if (keyLength2 != vectorSize) {
regionPtr2 += (keyLength2 + 1);
}
else {
regionPtr2 += vectorSize;
for (; U.getByte(regionPtr2) != KEY_VALUE_SEPARATOR; regionPtr2++)
;
keyLength2 = (int) (regionPtr2 - keyStartPtr2);
regionPtr2++;
}
int entryIdx2 = map.calculateEntryIndex(keyStartPtr2, keyLength2);
int entryOffset2 = map.putKey(keyVector2, keyStartPtr2, keyLength2, entryIdx2);
long word2 = U.getLong(regionPtr2);
if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
word2 = Long.reverseBytes(word2);
}
regionPtr2 = extractValue(regionPtr2, word2, map, entryOffset2);
}
}
}
/**
* Region processor request
* Region processor task
*/
private static final class Request {
private final Arena arena;
private final Queue<Task> sharedTasks;
private final Result result;
private Request(Arena arena, Queue<Task> sharedTasks, Result result) {
this.arena = arena;
this.sharedTasks = sharedTasks;
this.result = result;
}
}
private static final class Task {
private final FileChannel fileChannel;
@ -451,6 +550,23 @@ public class CalculateAverage_serkan_ozal {
}
/**
* Region processor request
*/
private static final class Request {
private final Arena arena;
private final Queue<Task> sharedTasks;
private final Result result;
private Request(Arena arena, Queue<Task> sharedTasks, Result result) {
this.arena = arena;
this.sharedTasks = sharedTasks;
this.result = result;
}
}
/**
* Region processor response
*/
@ -555,6 +671,9 @@ public class CalculateAverage_serkan_ozal {
}
/**
* Custom map implementation to store results
*/
private static final class OpenMap {
// Layout
@ -585,21 +704,22 @@ public class CalculateAverage_serkan_ozal {
private static final int ENTRY_MASK = MAP_SIZE - 1;
private static final int KEY_ARRAY_OFFSET = KEY_OFFSET - Unsafe.ARRAY_BYTE_BASE_OFFSET;
private static final int HASH_SEED = 0x9E3779B9;
private static final int HASH_ROTATE = 5;
private final byte[] data;
private final long[] entryOffsets;
private final int[] entryOffsets;
private int entryOffsetIdx;
private OpenMap() {
this.data = new byte[MAP_SIZE];
// Max number of unique keys are 10K, so 1 << 14 (16384) is long enough to hold offsets for all of them
this.entryOffsets = new long[1 << 14];
this.entryOffsets = new int[1 << 14];
this.entryOffsetIdx = 0;
}
// Credits: merykitty
private static int calculateKeyHash(long address, int keyLength) {
int seed = 0x9E3779B9;
int rotate = 5;
private int calculateEntryIndex(long address, int keyLength) {
int x, y;
if (keyLength >= Integer.BYTES) {
x = U.getInt(address);
@ -609,19 +729,17 @@ public class CalculateAverage_serkan_ozal {
x = U.getByte(address);
y = U.getByte(address + keyLength - Byte.BYTES);
}
return (Integer.rotateLeft(x * seed, rotate) ^ y) * seed;
// Calculate key hash
int keyHash = (Integer.rotateLeft(x * HASH_SEED, HASH_ROTATE) ^ y) * HASH_SEED;
// Get the position of the entry in the linear map based on calculated hash
return (keyHash & ENTRY_HASH_MASK) << ENTRY_SIZE_SHIFT;
}
private long putKey(ByteVector keyVector, long keyStartAddress, int keyLength) {
// Calculate hash of key
int keyHash = calculateKeyHash(keyStartAddress, keyLength);
// and get the position of the entry in the linear map based on calculated hash
int idx = (keyHash & ENTRY_HASH_MASK) << ENTRY_SIZE_SHIFT;
private int putKey(ByteVector keyVector, long keyStartAddress, int keyLength, int entryIdx) {
// Start searching from the calculated position
// and continue until find an available slot in case of hash collision
// TODO Prevent infinite loop if all the slots are in use for other keys
for (long entryOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET + idx;; entryOffset = (entryOffset + ENTRY_SIZE) & ENTRY_MASK) {
for (int entryOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET + entryIdx;; entryOffset = (entryOffset + ENTRY_SIZE) & ENTRY_MASK) {
int keySize = U.getInt(data, entryOffset + KEY_SIZE_OFFSET);
// Check whether current index is empty (no another key is inserted yet)
if (keySize == 0) {
@ -633,33 +751,27 @@ public class CalculateAverage_serkan_ozal {
entryOffsets[entryOffsetIdx++] = entryOffset;
return entryOffset;
}
int keyStartArrayOffset = (int) entryOffset + KEY_ARRAY_OFFSET;
// Check for hash collision (hashes are same, but keys are different).
// If there is no collision (both hashes and keys are equals), return current slot's offset.
// Otherwise, continue iterating until find an available slot.
if (keySize == keyLength && keysEqual(keyVector, keyStartAddress, keyLength, keyStartArrayOffset)) {
if (keySize == keyLength && keysEqual(keyVector, keyStartAddress, keyLength, entryOffset + KEY_ARRAY_OFFSET)) {
return entryOffset;
}
}
}
private boolean keysEqual(ByteVector keyVector, long keyStartAddress, int keyLength, int keyStartArrayOffset) {
int keyCheckIdx = 0;
if (keyVector != null) {
// Use vectorized search for the comparison of keys.
// Since majority of the city names >= 8 bytes and <= 16 bytes,
// this way is more efficient (according to my experiments) than any other comparisons (byte by byte or 2 longs).
ByteVector entryKeyVector = ByteVector.fromArray(BYTE_SPECIES, data, keyStartArrayOffset);
long eqMask = keyVector.compare(VectorOperators.EQ, entryKeyVector).toLong();
int eqCount = Long.numberOfTrailingZeros(~eqMask);
if (eqCount >= keyLength) {
int eqCount = keyVector.compare(VectorOperators.EQ, entryKeyVector).trueCount();
if (eqCount == keyLength) {
return true;
}
else if (keyLength <= BYTE_SPECIES_SIZE) {
return false;
}
keyCheckIdx = BYTE_SPECIES_SIZE;
}
// Compare remaining parts of the keys
@ -671,7 +783,7 @@ public class CalculateAverage_serkan_ozal {
long keyStartOffset = keyStartArrayOffset + Unsafe.ARRAY_BYTE_BASE_OFFSET;
int alignedKeyLength = normalizedKeyLength & 0xFFFFFFF8;
int i;
for (i = keyCheckIdx; i < alignedKeyLength; i += Long.BYTES) {
for (i = BYTE_SPECIES_SIZE; i < alignedKeyLength; i += Long.BYTES) {
if (U.getLong(keyStartAddress + i) != U.getLong(data, keyStartOffset + i)) {
return false;
}
@ -690,18 +802,18 @@ public class CalculateAverage_serkan_ozal {
return wordA == wordB;
}
private void putValue(long entryOffset, int value) {
long countOffset = entryOffset + COUNT_OFFSET;
private void putValue(int entryOffset, int value) {
int countOffset = entryOffset + COUNT_OFFSET;
U.putInt(data, countOffset, U.getInt(data, countOffset) + 1);
long minValueOffset = entryOffset + MIN_VALUE_OFFSET;
int minValueOffset = entryOffset + MIN_VALUE_OFFSET;
if (value < U.getShort(data, minValueOffset)) {
U.putShort(data, minValueOffset, (short) value);
}
long maxValueOffset = entryOffset + MAX_VALUE_OFFSET;
int maxValueOffset = entryOffset + MAX_VALUE_OFFSET;
if (value > U.getShort(data, maxValueOffset)) {
U.putShort(data, maxValueOffset, (short) value);
}
long sumOffset = entryOffset + VALUE_SUM_OFFSET;
int sumOffset = entryOffset + VALUE_SUM_OFFSET;
U.putLong(data, sumOffset, U.getLong(data, sumOffset) + value);
}
@ -709,13 +821,13 @@ public class CalculateAverage_serkan_ozal {
// Merge this local map into global result map
Arrays.sort(entryOffsets, 0, entryOffsetIdx);
for (int i = 0; i < entryOffsetIdx; i++) {
long entryOffset = entryOffsets[i];
int entryOffset = entryOffsets[i];
int keyLength = U.getInt(data, entryOffset + KEY_SIZE_OFFSET);
if (keyLength == 0) {
// No entry is available for this index, so continue iterating
continue;
}
int entryArrayIdx = (int) (entryOffset + KEY_OFFSET - Unsafe.ARRAY_BYTE_BASE_OFFSET);
int entryArrayIdx = entryOffset + KEY_OFFSET - Unsafe.ARRAY_BYTE_BASE_OFFSET;
String key = new String(data, entryArrayIdx, keyLength, StandardCharsets.UTF_8);
int count = U.getInt(data, entryOffset + COUNT_OFFSET);
short minValue = U.getShort(data, entryOffset + MIN_VALUE_OFFSET);