One last improvement for thomaswue (#702)

* Combine <8 and 8-16 cases into one case.

* Adopt mask-based approach for the <16 length city fast path (idea of Van Phu Do).

* Slightly improved code layout.

* Update perf number.
This commit is contained in:
Thomas Wuerthinger 2024-02-01 10:57:05 +01:00 committed by GitHub
parent 4debc7c5dd
commit 241d42ca66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -27,11 +27,14 @@ import java.util.concurrent.atomic.AtomicLong;
* split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread.
* Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in
* the end.
* Runs in 0.39s on an Intel i9-13900K.
* Runs in 0.31s on an Intel i9-13900K while the reference implementation takes 120.37s.
* Credit:
* Quan Anh Mai for branchless number parsing code
* Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea
* Artsiom Korzun for showing the benefits of work stealing at 2MB segments instead of equal split between workers
* Jaromir Hamala for showing that avoiding the branch misprediction between <8 and 8-16 cases is a big win even if
* more work is performed
* Van Phu DO for demonstrating the lookup tables based on masks instead of bit shifting
*/
public class CalculateAverage_thomaswue {
private static final String FILE = "./measurements.txt";
@ -141,9 +144,15 @@ public class CalculateAverage_thomaswue {
long delimiterMask1 = findDelimiter(word1);
long delimiterMask2 = findDelimiter(word2);
long delimiterMask3 = findDelimiter(word3);
Result existingResult1 = findResult(word1, delimiterMask1, scanner1, results, collectedResults);
Result existingResult2 = findResult(word2, delimiterMask2, scanner2, results, collectedResults);
Result existingResult3 = findResult(word3, delimiterMask3, scanner3, results, collectedResults);
long word1b = scanner1.getLongAt(scanner1.pos() + 8);
long word2b = scanner2.getLongAt(scanner2.pos() + 8);
long word3b = scanner3.getLongAt(scanner3.pos() + 8);
long delimiterMask1b = findDelimiter(word1b);
long delimiterMask2b = findDelimiter(word2b);
long delimiterMask3b = findDelimiter(word3b);
Result existingResult1 = findResult(word1, delimiterMask1, word1b, delimiterMask1b, scanner1, results, collectedResults);
Result existingResult2 = findResult(word2, delimiterMask2, word2b, delimiterMask2b, scanner2, results, collectedResults);
Result existingResult3 = findResult(word3, delimiterMask3, word3b, delimiterMask3b, scanner3, results, collectedResults);
long number1 = scanNumber(scanner1);
long number2 = scanNumber(scanner2);
long number3 = scanNumber(scanner3);
@ -155,76 +164,70 @@ public class CalculateAverage_thomaswue {
while (scanner1.hasNext()) {
long word = scanner1.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1));
long wordB = scanner1.getLongAt(scanner1.pos() + 8);
long posB = findDelimiter(wordB);
record(findResult(word, pos, wordB, posB, scanner1, results, collectedResults), scanNumber(scanner1));
}
while (scanner2.hasNext()) {
long word = scanner2.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2));
long wordB = scanner2.getLongAt(scanner2.pos() + 8);
long posB = findDelimiter(wordB);
record(findResult(word, pos, wordB, posB, scanner2, results, collectedResults), scanNumber(scanner2));
}
while (scanner3.hasNext()) {
long word = scanner3.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3));
long wordB = scanner3.getLongAt(scanner3.pos() + 8);
long posB = findDelimiter(wordB);
record(findResult(word, pos, wordB, posB, scanner3, results, collectedResults), scanNumber(scanner3));
}
}
}
private static Result findResult(long initialWord, long initialDelimiterMask, Scanner scanner, Result[] results, List<Result> collectedResults) {
private static final long[] MASK1 = new long[]{ 0xFFL, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL, 0xFFFFFFFFFFL, 0xFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL,
0xFFFFFFFFFFFFFFFFL };
private static final long[] MASK2 = new long[]{ 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0xFFFFFFFFFFFFFFFFL };
private static Result findResult(long initialWord, long initialDelimiterMask, long wordB, long delimiterMaskB, Scanner scanner, Result[] results,
List<Result> collectedResults) {
Result existingResult;
long word = initialWord;
long delimiterMask = initialDelimiterMask;
long hash;
long nameAddress = scanner.pos();
// Search for ';', one long at a time. There are two common cases that are specially treated:
// (b) the ';' is found in the first 16 bytes
if (delimiterMask != 0) {
// Special case for when the ';' is found in the first 8 bytes.
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash = word;
long word2 = wordB;
long delimiterMask2 = delimiterMaskB;
if ((delimiterMask | delimiterMask2) != 0) {
int letterCount1 = Long.numberOfTrailingZeros(delimiterMask) >>> 3; // value between 1 and 8
int letterCount2 = Long.numberOfTrailingZeros(delimiterMask2) >>> 3; // value between 0 and 8
long mask = MASK2[letterCount1];
word = word & MASK1[letterCount1];
word2 = mask & word2 & MASK1[letterCount2];
hash = word ^ word2;
existingResult = results[hashToIndex(hash, results)];
if (existingResult != null && existingResult.lastNameLong == word) {
scanner.add(letterCount1 + (letterCount2 & mask));
if (existingResult != null && existingResult.firstNameWord == word && existingResult.secondNameWord == word2) {
return existingResult;
}
}
else {
// Special case for when the ';' is found in bytes 9-16.
hash = word;
long prevWord = word;
scanner.add(8);
word = scanner.getLong();
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
existingResult = results[hashToIndex(hash, results)];
if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) {
return existingResult;
// Slow-path for when the ';' could not be found in the first 16 bytes.
hash = word ^ word2;
scanner.add(16);
while (true) {
word = scanner.getLong();
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
break;
}
}
else {
// Slow-path for when the ';' could not be found in the first 16 bytes.
scanner.add(8);
hash ^= word;
while (true) {
word = scanner.getLong();
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
break;
}
else {
scanner.add(8);
hash ^= word;
}
else {
scanner.add(8);
hash ^= word;
}
}
}
@ -249,8 +252,8 @@ public class CalculateAverage_thomaswue {
}
}
int remainingShift = (64 - (nameLength + 1 - i) << 3);
if (existingResult.lastNameLong == (scanner.getLongAt(nameAddress + i) << remainingShift)) {
int remainingShift = (64 - ((nameLength + 1 - i) << 3));
if (((scanner.getLongAt(existingResult.nameAddress + i) ^ (scanner.getLongAt(nameAddress + i))) << remainingShift) == 0) {
break;
}
else {
@ -297,7 +300,7 @@ public class CalculateAverage_thomaswue {
}
private static int hashToIndex(long hash, Result[] results) {
long hashAsInt = hash ^ (hash >>> 37) ^ (hash >>> 17);
long hashAsInt = hash ^ (hash >>> 33) ^ (hash >>> 15);
return (int) (hashAsInt & (results.length - 1));
}
@ -324,21 +327,23 @@ public class CalculateAverage_thomaswue {
private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List<Result> collectedResults) {
Result r = new Result();
results[hash] = r;
int i = 0;
for (; i < nameLength + 1 - Long.BYTES; i += Long.BYTES) {
int totalLength = nameLength + 1;
r.firstNameWord = scanner.getLongAt(nameAddress);
r.secondNameWord = scanner.getLongAt(nameAddress + 8);
if (totalLength <= 8) {
r.firstNameWord = r.firstNameWord & MASK1[totalLength - 1];
r.secondNameWord = 0;
}
if (nameLength + 1 > 8) {
r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8);
else if (totalLength < 16) {
r.secondNameWord = r.secondNameWord & MASK1[totalLength - 9];
}
int remainingShift = (64 - (nameLength + 1 - i) << 3);
r.lastNameLong = (scanner.getLongAt(nameAddress + i) << remainingShift);
r.nameAddress = nameAddress;
collectedResults.add(r);
return r;
}
private static final class Result {
long lastNameLong, secondLastNameLong;
long firstNameWord, secondNameWord;
short min, max;
int count;
long sum;