Use SIMD for search for delimiter and name compare
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							9e9e533401
						
					
				
				
					commit
					243388ad7b
				
			| @@ -25,29 +25,33 @@ import java.nio.channels.FileChannel.MapMode; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.nio.file.Path; | ||||
| import java.nio.file.StandardOpenOption; | ||||
| import java.util.HashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import java.util.TreeMap; | ||||
| import java.util.*; | ||||
| import java.util.stream.IntStream; | ||||
|  | ||||
| /** | ||||
|  * Simple solution that memory maps the input file, then splits it into one segment per available core and uses | ||||
|  * sun.misc.Unsafe to directly access the mapped memory. | ||||
|  * | ||||
|  * Runs in 0.92s on my Intel i9-13900K | ||||
|  * Perf stats: | ||||
|  *     65,004,666,383      cpu_core/cycles/ | ||||
|  *     71,141,249,972      cpu_atom/cycles/ | ||||
|  */ | ||||
| public class CalculateAverage_thomaswue { | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|  | ||||
|     // Holding the current result for a single city. | ||||
|     private static class Result { | ||||
|         int min; | ||||
|         int max; | ||||
|         short min; | ||||
|         short max; | ||||
|         long sum; | ||||
|         int count; | ||||
|         final long nameAddress; | ||||
|         final int nameLength; | ||||
|  | ||||
|         private Result(long nameAddress, int nameLength, int value) { | ||||
|         private Result(long nameAddress, int value) { | ||||
|             this.nameAddress = nameAddress; | ||||
|             this.nameLength = nameLength; | ||||
|             this.min = value; | ||||
|             this.max = value; | ||||
|             this.min = (short) value; | ||||
|             this.max = (short) value; | ||||
|             this.sum = value; | ||||
|             this.count = 1; | ||||
|         } | ||||
| @@ -62,8 +66,8 @@ public class CalculateAverage_thomaswue { | ||||
|  | ||||
|         // Accumulate another result into this one. | ||||
|         private void add(Result other) { | ||||
|             min = Math.min(min, other.min); | ||||
|             max = Math.max(max, other.max); | ||||
|             min = (short) Math.min(min, other.min); | ||||
|             max = (short) Math.max(max, other.max); | ||||
|             sum += other.sum; | ||||
|             count += other.count; | ||||
|         } | ||||
| @@ -77,7 +81,7 @@ public class CalculateAverage_thomaswue { | ||||
|         // Parallel processing of segments. | ||||
|         List<HashMap<String, Result>> allResults = IntStream.range(0, chunks.length - 1).mapToObj(chunkIndex -> { | ||||
|             HashMap<String, Result> cities = HashMap.newHashMap(1 << 10); | ||||
|             Result[] results = new Result[1 << 14]; | ||||
|             Result[] results = new Result[1 << 18]; | ||||
|             parseLoop(chunks[chunkIndex], chunks[chunkIndex + 1], results, cities); | ||||
|             return cities; | ||||
|         }).parallel().toList(); | ||||
| @@ -86,13 +90,10 @@ public class CalculateAverage_thomaswue { | ||||
|         HashMap<String, Result> result = allResults.getFirst(); | ||||
|         for (int i = 1; i < allResults.size(); ++i) { | ||||
|             for (Map.Entry<String, Result> entry : allResults.get(i).entrySet()) { | ||||
|                 Result current = result.get(entry.getKey()); | ||||
|                 Result current = result.putIfAbsent(entry.getKey(), entry.getValue()); | ||||
|                 if (current != null) { | ||||
|                     current.add(entry.getValue()); | ||||
|                 } | ||||
|                 else { | ||||
|                     result.put(entry.getKey(), entry.getValue()); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
| @@ -113,34 +114,46 @@ public class CalculateAverage_thomaswue { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     static boolean unsafeEquals(long aStart, long aLength, long bStart, long bLength) { | ||||
|         if (aLength != bLength) { | ||||
|             return false; | ||||
|         } | ||||
|         for (int i = 0; i < aLength; ++i) { | ||||
|             if (UNSAFE.getByte(aStart + i) != UNSAFE.getByte(bStart + i)) { | ||||
|                 return false; | ||||
|             } | ||||
|         } | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     private static void parseLoop(long chunkStart, long chunkEnd, Result[] results, HashMap<String, Result> cities) { | ||||
|         long scanPtr = chunkStart; | ||||
|         byte b; | ||||
|         while (scanPtr < chunkEnd) { | ||||
|             long nameAddress = scanPtr; | ||||
|             int hash = 0; | ||||
|  | ||||
|             int hash = UNSAFE.getByte(scanPtr++); | ||||
|             while ((b = UNSAFE.getByte(scanPtr++)) != ';') { | ||||
|                 hash += b; | ||||
|                 hash += hash << 10; | ||||
|                 hash ^= hash >> 6; | ||||
|             // Skip first letter. | ||||
|             scanPtr++; | ||||
|  | ||||
|             // Scan for ';' delimiter, always 4 bytes at a time. | ||||
|             while (true) { | ||||
|                 int nextVal = UNSAFE.getInt(scanPtr); | ||||
|                 if ((nextVal & 0x3B) == 0x3B) { | ||||
|                     scanPtr++; | ||||
|                     break; | ||||
|                 } | ||||
|                 else if ((nextVal & 0x3B00) == 0x3B00) { | ||||
|                     scanPtr += 2; | ||||
|                     hash = hash ^ (nextVal & 0xFF); | ||||
|                     break; | ||||
|                 } | ||||
|                 else if ((nextVal & 0x3B0000) == 0x3B0000) { | ||||
|                     scanPtr += 3; | ||||
|                     hash = hash ^ (nextVal & 0xFFFF); | ||||
|                     break; | ||||
|                 } | ||||
|                 else if (((nextVal & 0x3B000000) == 0x3B000000)) { | ||||
|                     scanPtr += 4; | ||||
|                     hash = hash ^ (nextVal & 0xFFFFFF); | ||||
|                     break; | ||||
|                 } | ||||
|                 scanPtr += 4; | ||||
|                 hash = hash ^ nextVal; | ||||
|             } | ||||
|  | ||||
|             int nameLength = (int) (scanPtr - 1 - nameAddress); | ||||
|             hash = hash & (results.length - 1); | ||||
|             // Save length of name for later. | ||||
|             int nameLength = (int) (scanPtr - nameAddress - 1); | ||||
|  | ||||
|             // Parse number. | ||||
|             int number; | ||||
|             byte sign = UNSAFE.getByte(scanPtr++); | ||||
|             if (sign == '-') { | ||||
| @@ -161,26 +174,53 @@ public class CalculateAverage_thomaswue { | ||||
|                 number = number * 10 + (UNSAFE.getByte(scanPtr++) - '0'); | ||||
|             } | ||||
|  | ||||
|             // Final calculation for index into hash table. | ||||
|             int tableIndex = (((hash ^ (hash >>> 18)) & (results.length - 1))); | ||||
|             while (true) { | ||||
|                 Result existingResult = results[hash]; | ||||
|                 Result existingResult = results[tableIndex]; | ||||
|                 if (existingResult == null) { | ||||
|                     Result r = new Result(nameAddress, nameLength, number); | ||||
|                     results[hash] = r; | ||||
|                     byte[] bytes = new byte[nameLength]; | ||||
|                     UNSAFE.copyMemory(null, nameAddress, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength); | ||||
|                     cities.put(new String(bytes, StandardCharsets.UTF_8), r); | ||||
|                     break; | ||||
|                 } | ||||
|                 else if (unsafeEquals(existingResult.nameAddress, existingResult.nameLength, nameAddress, nameLength)) { | ||||
|                     existingResult.min = Math.min(existingResult.min, number); | ||||
|                     existingResult.max = Math.max(existingResult.max, number); | ||||
|                     existingResult.sum += number; | ||||
|                     existingResult.count++; | ||||
|                     newEntry(results, cities, nameAddress, number, tableIndex, nameLength); | ||||
|                     break; | ||||
|                 } | ||||
|                 else { | ||||
|                     // Collision error, try next. | ||||
|                     hash = (hash + 1) & (results.length - 1); | ||||
|                     // Check for collision. | ||||
|                     boolean result = true; | ||||
|                     int i = 0; | ||||
|                     if ((long) nameLength >= 8) { | ||||
|                         if (UNSAFE.getLong(existingResult.nameAddress) != UNSAFE.getLong(nameAddress)) { | ||||
|                             result = false; | ||||
|                         } | ||||
|                         else { | ||||
|                             i += 8; | ||||
|                         } | ||||
|                     } | ||||
|                     else if ((long) nameLength >= 4) { | ||||
|                         if (UNSAFE.getInt(existingResult.nameAddress) != UNSAFE.getInt(nameAddress)) { | ||||
|                             result = false; | ||||
|                         } | ||||
|                         else { | ||||
|                             i += 4; | ||||
|                         } | ||||
|                     } | ||||
|                     if (result) { | ||||
|                         for (; i < (long) nameLength; ++i) { | ||||
|                             if (UNSAFE.getByte(existingResult.nameAddress + i) != UNSAFE.getByte(nameAddress + i)) { | ||||
|                                 result = false; | ||||
|                                 break; | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|                     if (result) { | ||||
|                         existingResult.min = (short) Math.min(existingResult.min, number); | ||||
|                         existingResult.max = (short) Math.max(existingResult.max, number); | ||||
|                         existingResult.sum += number; | ||||
|                         existingResult.count++; | ||||
|                         break; | ||||
|                     } | ||||
|                     else { | ||||
|                         // Collision error, try next. | ||||
|                         tableIndex = (tableIndex + 1) & (results.length - 1); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|  | ||||
| @@ -189,6 +229,14 @@ public class CalculateAverage_thomaswue { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static void newEntry(Result[] results, HashMap<String, Result> cities, long nameAddress, int number, int hash, int nameLength) { | ||||
|         Result r = new Result(nameAddress, number); | ||||
|         results[hash] = r; | ||||
|         byte[] bytes = new byte[nameLength]; | ||||
|         UNSAFE.copyMemory(null, nameAddress, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength); | ||||
|         cities.put(new String(bytes, StandardCharsets.UTF_8), r); | ||||
|     } | ||||
|  | ||||
|     private static long[] getSegments(int numberOfChunks) throws IOException { | ||||
|         try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { | ||||
|             long fileSize = fileChannel.size(); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user