Adding Scanner object and also tuning for better branch prediction for about +6%. (#341)
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							dac38bc97f
						
					
				
				
					commit
					bd4cff945d
				
			| @@ -32,30 +32,25 @@ import java.util.stream.IntStream; | ||||
|  * Simple solution that memory maps the input file, then splits it into one segment per available core and uses | ||||
|  * sun.misc.Unsafe to directly access the mapped memory. Uses a long at a time when checking for collision. | ||||
|  * <p> | ||||
|  * Runs in 0.70s on my Intel i9-13900K | ||||
|  * Runs in 0.66s on my Intel i9-13900K | ||||
|  * Perf stats: | ||||
|  *     40,622,862,783      cpu_core/cycles/ | ||||
|  *     48,241,929,925      cpu_atom/cycles/ | ||||
|  *     35,935,262,091      cpu_core/cycles/ | ||||
|  *     47,305,591,173      cpu_atom/cycles/ | ||||
|  */ | ||||
| public class CalculateAverage_thomaswue { | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|  | ||||
|     // Holding the current result for a single city. | ||||
|     private static class Result { | ||||
|         final long nameAddress; | ||||
|         long lastNameLong; | ||||
|         int remainingShift; | ||||
|         int min; | ||||
|         int max; | ||||
|         long lastNameLong, secondLastNameLong, nameAddress; | ||||
|         int nameLength, remainingShift; | ||||
|         int min, max, count; | ||||
|         long sum; | ||||
|         int count; | ||||
|  | ||||
|         private Result(long nameAddress, int value) { | ||||
|         private Result(long nameAddress) { | ||||
|             this.nameAddress = nameAddress; | ||||
|             this.min = value; | ||||
|             this.max = value; | ||||
|             this.sum = value; | ||||
|             this.count = 1; | ||||
|             this.min = Integer.MAX_VALUE; | ||||
|             this.max = Integer.MIN_VALUE; | ||||
|         } | ||||
|  | ||||
|         public String toString() { | ||||
| @@ -73,6 +68,10 @@ public class CalculateAverage_thomaswue { | ||||
|             sum += other.sum; | ||||
|             count += other.count; | ||||
|         } | ||||
|  | ||||
|         // Decodes this entry's name on demand: a temporary Scanner positioned at nameAddress | ||||
|         // copies nameLength bytes and returns them as a UTF-8 String (see Scanner.getString). | ||||
|         public String calcName() { | ||||
|             return new Scanner(nameAddress, nameAddress + nameLength).getString(nameLength); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     public static void main(String[] args) throws IOException { | ||||
| @@ -81,122 +80,155 @@ public class CalculateAverage_thomaswue { | ||||
|         long[] chunks = getSegments(numberOfChunks); | ||||
|  | ||||
|         // Parallel processing of segments. | ||||
|         List<HashMap<String, Result>> allResults = IntStream.range(0, chunks.length - 1).mapToObj(chunkIndex -> { | ||||
|             HashMap<String, Result> cities = HashMap.newHashMap(1 << 10); | ||||
|             parseLoop(chunks[chunkIndex], chunks[chunkIndex + 1], cities); | ||||
|             return cities; | ||||
|         }).parallel().toList(); | ||||
|         List<List<Result>> allResults = IntStream.range(0, chunks.length - 1).mapToObj(chunkIndex -> parseLoop(chunks[chunkIndex], chunks[chunkIndex + 1])) | ||||
|                 .map(resultArray -> { | ||||
|                     List<Result> results = new ArrayList<>(); | ||||
|                     for (Result r : resultArray) { | ||||
|                         if (r != null) { | ||||
|                             results.add(r); | ||||
|                         } | ||||
|                     } | ||||
|                     return results; | ||||
|                 }).parallel().toList(); | ||||
|  | ||||
|         // Accumulate results sequentially. | ||||
|         HashMap<String, Result> result = allResults.getFirst(); | ||||
|         for (int i = 1; i < allResults.size(); ++i) { | ||||
|             for (Map.Entry<String, Result> entry : allResults.get(i).entrySet()) { | ||||
|                 Result current = result.putIfAbsent(entry.getKey(), entry.getValue()); | ||||
|         // Final output. | ||||
|         System.out.println(accumulateResults(allResults)); | ||||
|     } | ||||
|  | ||||
|     // Accumulate results sequentially for simplicity. | ||||
|     private static TreeMap<String, Result> accumulateResults(List<List<Result>> allResults) { | ||||
|         TreeMap<String, Result> result = new TreeMap<>(); | ||||
|         for (List<Result> resultArr : allResults) { | ||||
|             for (Result r : resultArr) { | ||||
|                 String name = r.calcName(); | ||||
|                 Result current = result.putIfAbsent(name, r); | ||||
|                 if (current != null) { | ||||
|                     current.add(entry.getValue()); | ||||
|                     current.add(r); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         // Final output. | ||||
|         System.out.println(new TreeMap<>(result)); | ||||
|         return result; | ||||
|     } | ||||
|  | ||||
|     private static final Unsafe UNSAFE = initUnsafe(); | ||||
|  | ||||
|     // Looks up the Unsafe class's "theUnsafe" field via reflection, makes it accessible and | ||||
|     // returns its value cast to Unsafe. Reflection failures (NoSuchFieldException, | ||||
|     // IllegalAccessException) are rethrown wrapped in a RuntimeException. | ||||
|     private static Unsafe initUnsafe() { | ||||
|         try { | ||||
|             Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); | ||||
|             theUnsafe.setAccessible(true); | ||||
|             // NOTE(review): presumably a static field, in which case the argument to get() is ignored - confirm. | ||||
|             return (Unsafe) theUnsafe.get(Unsafe.class); | ||||
|         } | ||||
|         catch (NoSuchFieldException | IllegalAccessException e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static void parseLoop(long chunkStart, long chunkEnd, HashMap<String, Result> cities) { | ||||
|     // Main parse loop. | ||||
|     private static Result[] parseLoop(long chunkStart, long chunkEnd) { | ||||
|         Result[] results = new Result[1 << 18]; | ||||
|         long scanPtr = chunkStart; | ||||
|         while (scanPtr < chunkEnd) { | ||||
|             long nameAddress = scanPtr; | ||||
|         Scanner scanner = new Scanner(chunkStart, chunkEnd); | ||||
|         while (scanner.hasNext()) { | ||||
|             long nameAddress = scanner.pos(); | ||||
|             long hash = 0; | ||||
|  | ||||
|             // Search for ';', one long at a time. | ||||
|             long word = UNSAFE.getLong(scanPtr); | ||||
|             long word = scanner.getLong(); | ||||
|             int pos = findDelimiter(word); | ||||
|             if (pos != 8) { | ||||
|                 scanPtr += pos; | ||||
|                 word = word & (-1L >>> ((8 - pos - 1) << 3)); | ||||
|                 scanner.add(pos); | ||||
|                 word = mask(word, pos); | ||||
|                 hash ^= word; | ||||
|  | ||||
|                 Result existingResult = results[hashToIndex(hash, results)]; | ||||
|                 if (existingResult != null && existingResult.lastNameLong == word) { | ||||
|                     scanAndRecord(scanner, existingResult); | ||||
|                     continue; | ||||
|                 } | ||||
|             } | ||||
|             else { | ||||
|                 scanPtr += 8; | ||||
|                 scanner.add(8); | ||||
|                 hash ^= word; | ||||
|                 while (true) { | ||||
|                     word = UNSAFE.getLong(scanPtr); | ||||
|                     pos = findDelimiter(word); | ||||
|                     if (pos != 8) { | ||||
|                         scanPtr += pos; | ||||
|                         word = word & (-1L >>> ((8 - pos - 1) << 3)); | ||||
|                         hash ^= word; | ||||
|                         break; | ||||
|                 long prevWord = word; | ||||
|                 word = scanner.getLong(); | ||||
|                 pos = findDelimiter(word); | ||||
|                 if (pos != 8) { | ||||
|                     scanner.add(pos); | ||||
|                     word = mask(word, pos); | ||||
|                     hash ^= word; | ||||
|                     Result existingResult = results[hashToIndex(hash, results)]; | ||||
|                     if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { | ||||
|                         scanAndRecord(scanner, existingResult); | ||||
|                         continue; | ||||
|                     } | ||||
|                     else { | ||||
|                         scanPtr += 8; | ||||
|                         hash ^= word; | ||||
|                 } | ||||
|                 else { | ||||
|                     scanner.add(8); | ||||
|                     hash ^= word; | ||||
|                     while (true) { | ||||
|                         word = scanner.getLong(); | ||||
|                         pos = findDelimiter(word); | ||||
|                         if (pos != 8) { | ||||
|                             scanner.add(pos); | ||||
|                             word = mask(word, pos); | ||||
|                             hash ^= word; | ||||
|                             break; | ||||
|                         } | ||||
|                         else { | ||||
|                             scanner.add(8); | ||||
|                             hash ^= word; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // Save length of name for later. | ||||
|             int nameLength = (int) (scanPtr - nameAddress); | ||||
|             scanPtr++; | ||||
|             int nameLength = (int) (scanner.pos() - nameAddress); | ||||
|             scanner.add(1); | ||||
|  | ||||
|             long numberWord = UNSAFE.getLong(scanPtr); | ||||
|             // The 4th binary digit of the ascii of a digit is 1 while | ||||
|             // that of the '.' is 0. This finds the decimal separator | ||||
|             // The value can be 12, 20, 28 | ||||
|             long numberWord = scanner.getLong(); | ||||
|             int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); | ||||
|             int number = convertIntoNumber(decimalSepPos, numberWord); | ||||
|  | ||||
|             // Skip past new line. | ||||
|             // scanPtr++; | ||||
|             scanPtr += (decimalSepPos >>> 3) + 3; | ||||
|             scanner.add((decimalSepPos >>> 3) + 3); | ||||
|  | ||||
|             // Final calculation for index into hash table. | ||||
|             int hashAsInt = (int) (hash ^ (hash >>> 32)); | ||||
|             int finalHash = (hashAsInt ^ (hashAsInt >>> 18)); | ||||
|             int tableIndex = (finalHash & (results.length - 1)); | ||||
|             int tableIndex = hashToIndex(hash, results); | ||||
|             outer: while (true) { | ||||
|                 Result existingResult = results[tableIndex]; | ||||
|                 if (existingResult == null) { | ||||
|                     newEntry(results, cities, nameAddress, number, tableIndex, nameLength); | ||||
|                     existingResult = newEntry(results, nameAddress, tableIndex, nameLength, scanner); | ||||
|                 } | ||||
|                 // Check for collision. | ||||
|                 int i = 0; | ||||
|                 for (; i < nameLength + 1 - 8; i += 8) { | ||||
|                     if (scanner.getLongAt(existingResult.nameAddress + i) != scanner.getLongAt(nameAddress + i)) { | ||||
|                         tableIndex = (tableIndex + 1) & (results.length - 1); | ||||
|                         continue outer; | ||||
|                     } | ||||
|                 } | ||||
|                 if (((existingResult.lastNameLong ^ scanner.getLongAt(nameAddress + i)) << existingResult.remainingShift) == 0) { | ||||
|                     record(existingResult, number); | ||||
|                     break; | ||||
|                 } | ||||
|                 else { | ||||
|                     // Check for collision. | ||||
|                     int i = 0; | ||||
|                     for (; i < nameLength + 1 - 8; i += 8) { | ||||
|                         if (UNSAFE.getLong(existingResult.nameAddress + i) != UNSAFE.getLong(nameAddress + i)) { | ||||
|                             tableIndex = (tableIndex + 1) & (results.length - 1); | ||||
|                             continue outer; | ||||
|                         } | ||||
|                     } | ||||
|                     if (((existingResult.lastNameLong ^ UNSAFE.getLong(nameAddress + i)) << existingResult.remainingShift) == 0) { | ||||
|                         existingResult.min = Math.min(existingResult.min, number); | ||||
|                         existingResult.max = Math.max(existingResult.max, number); | ||||
|                         existingResult.sum += number; | ||||
|                         existingResult.count++; | ||||
|                         break; | ||||
|                     } | ||||
|                     else { | ||||
|                         // Collision error, try next. | ||||
|                         tableIndex = (tableIndex + 1) & (results.length - 1); | ||||
|                     } | ||||
|                     // Collision error, try next. | ||||
|                     tableIndex = (tableIndex + 1) & (results.length - 1); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         return results; | ||||
|     } | ||||
|  | ||||
|     // Parses the number that follows the scanner's current byte and folds it into existingResult. | ||||
|     // parseLoop calls this when the name just scanned matched the table entry at its hash index, | ||||
|     // so the generic collision-checking path is skipped. | ||||
|     private static void scanAndRecord(Scanner scanPtr, Result existingResult) { | ||||
|         // Step over one byte; the callers in parseLoop stop the scanner on the delimiter found by findDelimiter. | ||||
|         scanPtr.add(1); | ||||
|         // Load the 8 bytes starting at the number. | ||||
|         long numberWord = scanPtr.getLong(); | ||||
|         // Bit 4 (0x10) is set in the ASCII digits but clear in '.', so the lowest clear bit 4 among bytes 1..3 | ||||
|         // of the word (counting from the least significant byte) marks the decimal separator: 12, 20 or 28. | ||||
|         int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); | ||||
|         int number = convertIntoNumber(decimalSepPos, numberWord); | ||||
|         // decimalSepPos >>> 3 is the separator's byte offset; + 3 also skips the separator and two more bytes | ||||
|         // (presumably one fractional digit and the line terminator - confirm against the input format). | ||||
|         scanPtr.add((decimalSepPos >>> 3) + 3); | ||||
|         record(existingResult, number); | ||||
|     } | ||||
|  | ||||
|     // Folds one measurement into the running aggregate: lowers min / raises max if needed, | ||||
|     // adds the value to sum and increments count. | ||||
|     private static void record(Result existingResult, int number) { | ||||
|         existingResult.min = Math.min(existingResult.min, number); | ||||
|         existingResult.max = Math.max(existingResult.max, number); | ||||
|         existingResult.sum += number; | ||||
|         existingResult.count++; | ||||
|     } | ||||
|  | ||||
|     // Reduces the 64-bit name hash to a slot index in results. | ||||
|     private static int hashToIndex(long hash, Result[] results) { | ||||
|         // Fold the upper 32 bits into the lower 32 so both halves influence the index. | ||||
|         int hashAsInt = (int) (hash ^ (hash >>> 32)); | ||||
|         // Mix the high bits of the int downwards as well. | ||||
|         int finalHash = (hashAsInt ^ (hashAsInt >>> 18)); | ||||
|         // Masking with length - 1 only covers every slot when the length is a power of two | ||||
|         // (parseLoop allocates 1 << 18 entries). | ||||
|         return (finalHash & (results.length - 1)); | ||||
|     } | ||||
|  | ||||
|     // Keeps the lowest pos + 1 bytes of word and zeroes everything above them. | ||||
|     // In parseLoop pos is the result of findDelimiter and is only passed here when it is not 8, | ||||
|     // so the byte at index pos itself stays part of the masked word. | ||||
|     private static long mask(long word, int pos) { | ||||
|         // (8 - pos - 1) << 3 is the number of high bits to clear: 8 bits per discarded byte. | ||||
|         return word & (-1L >>> ((8 - pos - 1) << 3)); | ||||
|     } | ||||
|  | ||||
|     // Special method to convert a number in the specific format into an int value without branches created by | ||||
| @@ -229,19 +261,18 @@ public class CalculateAverage_thomaswue { | ||||
|         return Long.numberOfTrailingZeros(tmp) >>> 3; | ||||
|     } | ||||
|  | ||||
|     private static void newEntry(Result[] results, HashMap<String, Result> cities, long nameAddress, int number, int hash, int nameLength) { | ||||
|         Result r = new Result(nameAddress, number); | ||||
|     private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner) { | ||||
|         Result r = new Result(nameAddress); | ||||
|         results[hash] = r; | ||||
|         byte[] bytes = new byte[nameLength]; | ||||
|  | ||||
|         int i = 0; | ||||
|         for (; i < nameLength + 1 - 8; i += 8) { | ||||
|             r.secondLastNameLong = (scanner.getLongAt(nameAddress + i)); | ||||
|         } | ||||
|         r.lastNameLong = UNSAFE.getLong(nameAddress + i); | ||||
|         r.remainingShift = (64 - (nameLength + 1 - i) << 3); | ||||
|         UNSAFE.copyMemory(null, nameAddress, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength); | ||||
|         String nameAsString = new String(bytes, StandardCharsets.UTF_8); | ||||
|         cities.put(nameAsString, r); | ||||
|         r.lastNameLong = (scanner.getLongAt(nameAddress + i) & (-1L >>> r.remainingShift)); | ||||
|         r.nameLength = nameLength; | ||||
|         return r; | ||||
|     } | ||||
|  | ||||
|     private static long[] getSegments(int numberOfChunks) throws IOException { | ||||
| @@ -252,10 +283,11 @@ public class CalculateAverage_thomaswue { | ||||
|             long mappedAddress = fileChannel.map(MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); | ||||
|             chunks[0] = mappedAddress; | ||||
|             long endAddress = mappedAddress + fileSize; | ||||
|             Scanner s = new Scanner(mappedAddress, mappedAddress + fileSize); | ||||
|             for (int i = 1; i < numberOfChunks; ++i) { | ||||
|                 long chunkAddress = mappedAddress + i * segmentSize; | ||||
|                 // Align to first row start. | ||||
|                 while (chunkAddress < endAddress && UNSAFE.getByte(chunkAddress++) != '\n') { | ||||
|                 while (chunkAddress < endAddress && (s.getLongAt(chunkAddress++) & 0xFF) != '\n') { | ||||
|                     // nop | ||||
|                 } | ||||
|                 chunks[i] = Math.min(chunkAddress, endAddress); | ||||
| @@ -264,4 +296,53 @@ public class CalculateAverage_thomaswue { | ||||
|             return chunks; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Cursor over a range of raw memory addresses, read through Unsafe. pos is the current address | ||||
|     // and end the exclusive limit tested by hasNext(); none of the read methods check pos against end. | ||||
|     private static class Scanner { | ||||
|  | ||||
|         private static final Unsafe UNSAFE = initUnsafe(); | ||||
|  | ||||
|         // Obtains the Unsafe instance by reading the "theUnsafe" field reflectively; reflection | ||||
|         // failures are rethrown wrapped in a RuntimeException. | ||||
|         private static Unsafe initUnsafe() { | ||||
|             try { | ||||
|                 Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); | ||||
|                 theUnsafe.setAccessible(true); | ||||
|                 return (Unsafe) theUnsafe.get(Unsafe.class); | ||||
|             } | ||||
|             catch (NoSuchFieldException | IllegalAccessException e) { | ||||
|                 throw new RuntimeException(e); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         // Current address and exclusive end address of the scanned range. | ||||
|         long pos, end; | ||||
|  | ||||
|         // Creates a scanner positioned at start that reports hasNext() until pos reaches end. | ||||
|         public Scanner(long start, long end) { | ||||
|             this.pos = start; | ||||
|             this.end = end; | ||||
|         } | ||||
|  | ||||
|         // True while the current address is still below the end address. | ||||
|         boolean hasNext() { | ||||
|             return pos < end; | ||||
|         } | ||||
|  | ||||
|         // Returns the current address. | ||||
|         long pos() { | ||||
|             return pos; | ||||
|         } | ||||
|  | ||||
|         // Moves the current address by delta bytes; no bounds check is performed. | ||||
|         void add(int delta) { | ||||
|             pos += delta; | ||||
|         } | ||||
|  | ||||
|         // Reads 8 bytes at the current address without advancing. | ||||
|         long getLong() { | ||||
|             return UNSAFE.getLong(pos); | ||||
|         } | ||||
|  | ||||
|         // Reads 8 bytes at the given absolute address; the parameter shadows the pos field, | ||||
|         // so the scanner's own position is neither used nor changed. | ||||
|         long getLongAt(long pos) { | ||||
|             return UNSAFE.getLong(pos); | ||||
|         } | ||||
|  | ||||
|         // Copies nameLength bytes starting at the current address into a fresh byte array and | ||||
|         // decodes them as UTF-8. The position is not advanced. | ||||
|         public String getString(int nameLength) { | ||||
|             byte[] bytes = new byte[nameLength]; | ||||
|             // A null source object makes copyMemory treat pos as an absolute address. | ||||
|             UNSAFE.copyMemory(null, pos, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength); | ||||
|             return new String(bytes, StandardCharsets.UTF_8); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user