One last improvement for thomaswue (#702)
* Combine <8 and 8-16 cases into one case. * Adopt mask-based approach for the <16 length city fast path (idea of Van Phu Do). * Slightly improved code layout. * Update perf number.
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							4debc7c5dd
						
					
				
				
					commit
					241d42ca66
				
			| @@ -27,11 +27,14 @@ import java.util.concurrent.atomic.AtomicLong; | ||||
|  * split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread. | ||||
|  * Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in | ||||
|  * the end. | ||||
|  * Runs in 0.39s on an Intel i9-13900K. | ||||
|  * Runs in 0.31 on an Intel i9-13900K while the reference implementation takes 120.37s. | ||||
|  * Credit: | ||||
|  *  Quan Anh Mai for branchless number parsing code | ||||
|  *  Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea | ||||
|  *  Artsiom Korzun for showing the benefits of work stealing at 2MB segments instead of equal split between workers | ||||
|  *  Jaromir Hamala for showing that avoiding the branch misprediction between <8 and 8-16 cases is a big win even if | ||||
|  *  more work is performed | ||||
|  *  Van Phu DO for demonstrating the lookup tables based on masks instead of bit shifting | ||||
|  */ | ||||
| public class CalculateAverage_thomaswue { | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
| @@ -141,9 +144,15 @@ public class CalculateAverage_thomaswue { | ||||
|                 long delimiterMask1 = findDelimiter(word1); | ||||
|                 long delimiterMask2 = findDelimiter(word2); | ||||
|                 long delimiterMask3 = findDelimiter(word3); | ||||
|                 Result existingResult1 = findResult(word1, delimiterMask1, scanner1, results, collectedResults); | ||||
|                 Result existingResult2 = findResult(word2, delimiterMask2, scanner2, results, collectedResults); | ||||
|                 Result existingResult3 = findResult(word3, delimiterMask3, scanner3, results, collectedResults); | ||||
|                 long word1b = scanner1.getLongAt(scanner1.pos() + 8); | ||||
|                 long word2b = scanner2.getLongAt(scanner2.pos() + 8); | ||||
|                 long word3b = scanner3.getLongAt(scanner3.pos() + 8); | ||||
|                 long delimiterMask1b = findDelimiter(word1b); | ||||
|                 long delimiterMask2b = findDelimiter(word2b); | ||||
|                 long delimiterMask3b = findDelimiter(word3b); | ||||
|                 Result existingResult1 = findResult(word1, delimiterMask1, word1b, delimiterMask1b, scanner1, results, collectedResults); | ||||
|                 Result existingResult2 = findResult(word2, delimiterMask2, word2b, delimiterMask2b, scanner2, results, collectedResults); | ||||
|                 Result existingResult3 = findResult(word3, delimiterMask3, word3b, delimiterMask3b, scanner3, results, collectedResults); | ||||
|                 long number1 = scanNumber(scanner1); | ||||
|                 long number2 = scanNumber(scanner2); | ||||
|                 long number3 = scanNumber(scanner3); | ||||
| @@ -155,76 +164,70 @@ public class CalculateAverage_thomaswue { | ||||
|             while (scanner1.hasNext()) { | ||||
|                 long word = scanner1.getLong(); | ||||
|                 long pos = findDelimiter(word); | ||||
|                 record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1)); | ||||
|                 long wordB = scanner1.getLongAt(scanner1.pos() + 8); | ||||
|                 long posB = findDelimiter(wordB); | ||||
|                 record(findResult(word, pos, wordB, posB, scanner1, results, collectedResults), scanNumber(scanner1)); | ||||
|             } | ||||
|             while (scanner2.hasNext()) { | ||||
|                 long word = scanner2.getLong(); | ||||
|                 long pos = findDelimiter(word); | ||||
|                 record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2)); | ||||
|                 long wordB = scanner2.getLongAt(scanner2.pos() + 8); | ||||
|                 long posB = findDelimiter(wordB); | ||||
|                 record(findResult(word, pos, wordB, posB, scanner2, results, collectedResults), scanNumber(scanner2)); | ||||
|             } | ||||
|             while (scanner3.hasNext()) { | ||||
|                 long word = scanner3.getLong(); | ||||
|                 long pos = findDelimiter(word); | ||||
|                 record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3)); | ||||
|                 long wordB = scanner3.getLongAt(scanner3.pos() + 8); | ||||
|                 long posB = findDelimiter(wordB); | ||||
|                 record(findResult(word, pos, wordB, posB, scanner3, results, collectedResults), scanNumber(scanner3)); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static Result findResult(long initialWord, long initialDelimiterMask, Scanner scanner, Result[] results, List<Result> collectedResults) { | ||||
|     private static final long[] MASK1 = new long[]{ 0xFFL, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL, 0xFFFFFFFFFFL, 0xFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL, | ||||
|             0xFFFFFFFFFFFFFFFFL }; | ||||
|     private static final long[] MASK2 = new long[]{ 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0xFFFFFFFFFFFFFFFFL }; | ||||
|  | ||||
|     private static Result findResult(long initialWord, long initialDelimiterMask, long wordB, long delimiterMaskB, Scanner scanner, Result[] results, | ||||
|                                      List<Result> collectedResults) { | ||||
|         Result existingResult; | ||||
|         long word = initialWord; | ||||
|         long delimiterMask = initialDelimiterMask; | ||||
|         long hash; | ||||
|         long nameAddress = scanner.pos(); | ||||
|  | ||||
|         // Search for ';', one long at a time. There are two common cases that a specially treated: | ||||
|         // (b) the ';' is found in the first 16 bytes | ||||
|         if (delimiterMask != 0) { | ||||
|             // Special case for when the ';' is found in the first 8 bytes. | ||||
|             int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); | ||||
|             word = (word << (63 - trailingZeros)); | ||||
|             scanner.add(trailingZeros >>> 3); | ||||
|             hash = word; | ||||
|         long word2 = wordB; | ||||
|         long delimiterMask2 = delimiterMaskB; | ||||
|         if ((delimiterMask | delimiterMask2) != 0) { | ||||
|             int letterCount1 = Long.numberOfTrailingZeros(delimiterMask) >>> 3; // value between 1 and 8 | ||||
|             int letterCount2 = Long.numberOfTrailingZeros(delimiterMask2) >>> 3; // value between 0 and 8 | ||||
|             long mask = MASK2[letterCount1]; | ||||
|             word = word & MASK1[letterCount1]; | ||||
|             word2 = mask & word2 & MASK1[letterCount2]; | ||||
|             hash = word ^ word2; | ||||
|             existingResult = results[hashToIndex(hash, results)]; | ||||
|             if (existingResult != null && existingResult.lastNameLong == word) { | ||||
|             scanner.add(letterCount1 + (letterCount2 & mask)); | ||||
|             if (existingResult != null && existingResult.firstNameWord == word && existingResult.secondNameWord == word2) { | ||||
|                 return existingResult; | ||||
|             } | ||||
|         } | ||||
|         else { | ||||
|             // Special case for when the ';' is found in bytes 9-16. | ||||
|             hash = word; | ||||
|             long prevWord = word; | ||||
|             scanner.add(8); | ||||
|             word = scanner.getLong(); | ||||
|             delimiterMask = findDelimiter(word); | ||||
|             if (delimiterMask != 0) { | ||||
|                 int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); | ||||
|                 word = (word << (63 - trailingZeros)); | ||||
|                 scanner.add(trailingZeros >>> 3); | ||||
|                 hash ^= word; | ||||
|                 existingResult = results[hashToIndex(hash, results)]; | ||||
|                 if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { | ||||
|                     return existingResult; | ||||
|             // Slow-path for when the ';' could not be found in the first 16 bytes. | ||||
|             hash = word ^ word2; | ||||
|             scanner.add(16); | ||||
|             while (true) { | ||||
|                 word = scanner.getLong(); | ||||
|                 delimiterMask = findDelimiter(word); | ||||
|                 if (delimiterMask != 0) { | ||||
|                     int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); | ||||
|                     word = (word << (63 - trailingZeros)); | ||||
|                     scanner.add(trailingZeros >>> 3); | ||||
|                     hash ^= word; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             else { | ||||
|                 // Slow-path for when the ';' could not be found in the first 16 bytes. | ||||
|                 scanner.add(8); | ||||
|                 hash ^= word; | ||||
|                 while (true) { | ||||
|                     word = scanner.getLong(); | ||||
|                     delimiterMask = findDelimiter(word); | ||||
|                     if (delimiterMask != 0) { | ||||
|                         int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); | ||||
|                         word = (word << (63 - trailingZeros)); | ||||
|                         scanner.add(trailingZeros >>> 3); | ||||
|                         hash ^= word; | ||||
|                         break; | ||||
|                     } | ||||
|                     else { | ||||
|                         scanner.add(8); | ||||
|                         hash ^= word; | ||||
|                     } | ||||
|                 else { | ||||
|                     scanner.add(8); | ||||
|                     hash ^= word; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| @@ -249,8 +252,8 @@ public class CalculateAverage_thomaswue { | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             int remainingShift = (64 - (nameLength + 1 - i) << 3); | ||||
|             if (existingResult.lastNameLong == (scanner.getLongAt(nameAddress + i) << remainingShift)) { | ||||
|             int remainingShift = (64 - ((nameLength + 1 - i) << 3)); | ||||
|             if (((scanner.getLongAt(existingResult.nameAddress + i) ^ (scanner.getLongAt(nameAddress + i))) << remainingShift) == 0) { | ||||
|                 break; | ||||
|             } | ||||
|             else { | ||||
| @@ -297,7 +300,7 @@ public class CalculateAverage_thomaswue { | ||||
|     } | ||||
|  | ||||
|     private static int hashToIndex(long hash, Result[] results) { | ||||
|         long hashAsInt = hash ^ (hash >>> 37) ^ (hash >>> 17); | ||||
|         long hashAsInt = hash ^ (hash >>> 33) ^ (hash >>> 15); | ||||
|         return (int) (hashAsInt & (results.length - 1)); | ||||
|     } | ||||
|  | ||||
| @@ -324,21 +327,23 @@ public class CalculateAverage_thomaswue { | ||||
|     private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List<Result> collectedResults) { | ||||
|         Result r = new Result(); | ||||
|         results[hash] = r; | ||||
|         int i = 0; | ||||
|         for (; i < nameLength + 1 - Long.BYTES; i += Long.BYTES) { | ||||
|         int totalLength = nameLength + 1; | ||||
|         r.firstNameWord = scanner.getLongAt(nameAddress); | ||||
|         r.secondNameWord = scanner.getLongAt(nameAddress + 8); | ||||
|         if (totalLength <= 8) { | ||||
|             r.firstNameWord = r.firstNameWord & MASK1[totalLength - 1]; | ||||
|             r.secondNameWord = 0; | ||||
|         } | ||||
|         if (nameLength + 1 > 8) { | ||||
|             r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8); | ||||
|         else if (totalLength < 16) { | ||||
|             r.secondNameWord = r.secondNameWord & MASK1[totalLength - 9]; | ||||
|         } | ||||
|         int remainingShift = (64 - (nameLength + 1 - i) << 3); | ||||
|         r.lastNameLong = (scanner.getLongAt(nameAddress + i) << remainingShift); | ||||
|         r.nameAddress = nameAddress; | ||||
|         collectedResults.add(r); | ||||
|         return r; | ||||
|     } | ||||
|  | ||||
|     private static final class Result { | ||||
|         long lastNameLong, secondLastNameLong; | ||||
|         long firstNameWord, secondNameWord; | ||||
|         short min, max; | ||||
|         int count; | ||||
|         long sum; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user