Clean up, fine tuning, credit section for thomaswue (#646)
* Some clean up, fine tuning, removing non-supported options, added credit section and additional comments.
* Put license header year back to 2023 to pass checks.
* Remove static linking (as it requires some more setup on the target machine).
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							5b9703283a
						
					
				
				
					commit
					036f9a01b1
				
			| @@ -16,122 +16,68 @@ | ||||
| package dev.morling.onebrc; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.nio.ByteBuffer; | ||||
| import java.nio.ByteOrder; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.nio.file.Path; | ||||
| import java.nio.file.StandardOpenOption; | ||||
| import java.util.*; | ||||
| import java.util.concurrent.atomic.AtomicLong; | ||||
| import java.util.stream.IntStream; | ||||
|  | ||||
| /** | ||||
|  * Simple solution that memory maps the input file, then splits it into one segment per available core and uses | ||||
|  * sun.misc.Unsafe to directly access the mapped memory. Uses a long at a time when checking for collision. | ||||
|  * <p> | ||||
|  * Runs in 0.41s on my Intel i9-13900K | ||||
|  * Perf stats: | ||||
|  *     25,286,227,376      cpu_core/cycles/ | ||||
|  *     26,833,723,225      cpu_atom/cycles/ | ||||
|  * The solution starts a child worker process for the actual work such that clean up of the memory mapping can occur | ||||
|  * while the main process already returns with the result. The worker then memory maps the input file, creates a worker | ||||
|  * thread per available core, and then processes segments of size {@link #SEGMENT_SIZE} at a time. The segments are | ||||
|  * split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread. | ||||
|  * Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in | ||||
|  * the end. | ||||
|  * | ||||
|  * Runs in 0.40s on an Intel i9-13900K. | ||||
|  * | ||||
|  * Credit: | ||||
|  *  Quan Anh Mai for branchless number parsing code | ||||
|  *  Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea | ||||
|  *  Artsiom Korzun for showing the benefits of work stealing at 2MB segments instead of equal split between workers | ||||
|  */ | ||||
| public class CalculateAverage_thomaswue { | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|     private static final int MIN_TEMP = -999; | ||||
|     private static final int MAX_TEMP = 999; | ||||
|  | ||||
|     // Holding the current result for a single city. | ||||
|     private static class Result { | ||||
|         long lastNameLong, secondLastNameLong; | ||||
|         long min, max; | ||||
|         long sum; | ||||
|         int count; | ||||
|         long[] name; | ||||
|         String nameAsString; | ||||
|  | ||||
|         private Result() { | ||||
|             this.min = MAX_TEMP; | ||||
|             this.max = MIN_TEMP; | ||||
|         } | ||||
|  | ||||
|         public String toString() { | ||||
|             return round(((double) min) / 10.0) + "/" + round((((double) sum) / 10.0) / count) + "/" + round(((double) max) / 10.0); | ||||
|         } | ||||
|  | ||||
|         private static double round(double value) { | ||||
|             return Math.round(value * 10.0) / 10.0; | ||||
|         } | ||||
|  | ||||
|         // Accumulate another result into this one. | ||||
|         private void add(Result other) { | ||||
|             if (other.min < min) { | ||||
|                 min = other.min; | ||||
|             } | ||||
|             if (other.max > max) { | ||||
|                 max = other.max; | ||||
|             } | ||||
|             sum += other.sum; | ||||
|             count += other.count; | ||||
|         } | ||||
|  | ||||
|         public String calcName() { | ||||
|             if (nameAsString == null) { | ||||
|                 ByteBuffer bb = ByteBuffer.allocate(name.length * Long.BYTES).order(ByteOrder.nativeOrder()); | ||||
|                 bb.asLongBuffer().put(name); | ||||
|                 byte[] array = bb.array(); | ||||
|                 int i = 0; | ||||
|                 while (array[i++] != ';') | ||||
|                     ; | ||||
|                 nameAsString = new String(array, 0, i - 1, StandardCharsets.UTF_8); | ||||
|             } | ||||
|             return nameAsString; | ||||
|         } | ||||
|     } | ||||
|     private static final int MAX_NAME_LENGTH = 100; | ||||
|     private static final int MAX_CITIES = 10000; | ||||
|     private static final int SEGMENT_SIZE = 1 << 21; | ||||
|     private static final int HASH_TABLE_SIZE = 1 << 17; | ||||
|  | ||||
|     public static void main(String[] args) throws IOException, InterruptedException { | ||||
|         // Start worker subprocess if this process is not the worker. | ||||
|         if (args.length == 0 || !("--worker".equals(args[0]))) { | ||||
|             spawnWorker(); | ||||
|             return; | ||||
|         } | ||||
|         // Calculate input segments. | ||||
|  | ||||
|         int numberOfWorkers = Runtime.getRuntime().availableProcessors(); | ||||
|         final AtomicLong cursor = new AtomicLong(); | ||||
|         final long fileEnd; | ||||
|         final long fileStart; | ||||
|  | ||||
|         try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { | ||||
|         try (var fileChannel = FileChannel.open(java.nio.file.Path.of(FILE), java.nio.file.StandardOpenOption.READ)) { | ||||
|             long fileSize = fileChannel.size(); | ||||
|             fileStart = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, java.lang.foreign.Arena.global()).address(); | ||||
|             cursor.set(fileStart); | ||||
|             fileEnd = fileStart + fileSize; | ||||
|         } | ||||
|             final long fileStart = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, java.lang.foreign.Arena.global()).address(); | ||||
|             final long fileEnd = fileStart + fileSize; | ||||
|             final AtomicLong cursor = new AtomicLong(fileStart); | ||||
|  | ||||
|         // Parallel processing of segments. | ||||
|         Thread[] threads = new Thread[numberOfWorkers]; | ||||
|         List<Result>[] allResults = new List[numberOfWorkers]; | ||||
|         for (int i = 0; i < threads.length; ++i) { | ||||
|             final int index = i; | ||||
|             threads[i] = new Thread(() -> { | ||||
|                 Result[] resultArray = parseLoop(cursor, fileEnd, fileStart); | ||||
|                 List<Result> results = new ArrayList<>(500); | ||||
|                 for (Result r : resultArray) { | ||||
|                     if (r != null) { | ||||
|                         r.calcName(); | ||||
|                         results.add(r); | ||||
|                     } | ||||
|                 } | ||||
|                 allResults[index] = results; | ||||
|             }); | ||||
|             threads[i].start(); | ||||
|         } | ||||
|             // Parallel processing of segments. | ||||
|             Thread[] threads = new Thread[numberOfWorkers]; | ||||
|             List<Result>[] allResults = new List[numberOfWorkers]; | ||||
|             for (int i = 0; i < threads.length; ++i) { | ||||
|                 final int index = i; | ||||
|                 threads[i] = new Thread(() -> { | ||||
|                     List<Result> results = new ArrayList<>(MAX_CITIES); | ||||
|                     parseLoop(cursor, fileEnd, fileStart, results); | ||||
|                     allResults[index] = results; | ||||
|                 }); | ||||
|                 threads[i].start(); | ||||
|             } | ||||
|             for (Thread thread : threads) { | ||||
|                 thread.join(); | ||||
|             } | ||||
|  | ||||
|         for (Thread thread : threads) { | ||||
|             thread.join(); | ||||
|             // Final output. | ||||
|             System.out.println(accumulateResults(allResults)); | ||||
|             System.out.close(); | ||||
|         } | ||||
|  | ||||
|         // Final output. | ||||
|         System.out.println(accumulateResults(allResults)); | ||||
|         System.out.close(); | ||||
|     } | ||||
|  | ||||
|     private static void spawnWorker() throws IOException { | ||||
| @@ -144,31 +90,30 @@ public class CalculateAverage_thomaswue { | ||||
|                 .start().getInputStream().transferTo(System.out); | ||||
|     } | ||||
|  | ||||
|     // Accumulate results sequentially for simplicity. | ||||
|     private static TreeMap<String, Result> accumulateResults(List<Result>[] allResults) { | ||||
|         TreeMap<String, Result> result = new TreeMap<>(); | ||||
|         for (List<Result> resultArr : allResults) { | ||||
|             for (Result r : resultArr) { | ||||
|                 String name = r.calcName(); | ||||
|                 Result current = result.putIfAbsent(name, r); | ||||
|                 Result current = result.putIfAbsent(r.calcName(), r); | ||||
|                 if (current != null) { | ||||
|                     current.add(r); | ||||
|                     current.accumulate(r); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         return result; | ||||
|     } | ||||
|  | ||||
|     private static Result findResult(long initialWord, long initialPos, Scanner scanner, Result[] results) { | ||||
|  | ||||
|     private static Result findResult(long initialWord, long initialPos, Scanner scanner, Result[] results, List<Result> collectedResults) { | ||||
|         Result existingResult; | ||||
|         long word = initialWord; | ||||
|         long pos = initialPos; | ||||
|         long hash; | ||||
|         long nameAddress = scanner.pos(); | ||||
|  | ||||
|         // Search for ';', one long at a time. | ||||
|         // Search for ';', one long at a time. There are two common cases that are specially treated: | ||||
|         // (a) the ';' is found in the first 8 bytes | ||||
|         // (b) the ';' is found in the first 16 bytes | ||||
|         if (pos != 0) { | ||||
|             // Special case for when the ';' is found in the first 8 bytes. | ||||
|             pos = Long.numberOfTrailingZeros(pos) >>> 3; | ||||
|             scanner.add(pos); | ||||
|             word = mask(word, pos); | ||||
| @@ -180,11 +125,10 @@ public class CalculateAverage_thomaswue { | ||||
|             if (existingResult != null && existingResult.lastNameLong == word) { | ||||
|                 return existingResult; | ||||
|             } | ||||
|             else { | ||||
|                 scanner.setPos(nameAddress + pos); | ||||
|             } | ||||
|             scanner.setPos(nameAddress + pos); | ||||
|         } | ||||
|         else { | ||||
|             // Special case for when the ';' is found in bytes 9-16. | ||||
|             scanner.add(8); | ||||
|             hash = word; | ||||
|             long prevWord = word; | ||||
| @@ -201,11 +145,10 @@ public class CalculateAverage_thomaswue { | ||||
|                 if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { | ||||
|                     return existingResult; | ||||
|                 } | ||||
|                 else { | ||||
|                     scanner.setPos(nameAddress + pos + 8); | ||||
|                 } | ||||
|                 scanner.setPos(nameAddress + pos + 8); | ||||
|             } | ||||
|             else { | ||||
|                 // Slow-path for when the ';' could not be found in the first 16 bytes. | ||||
|                 scanner.add(8); | ||||
|                 hash ^= word; | ||||
|                 while (true) { | ||||
| @@ -234,20 +177,20 @@ public class CalculateAverage_thomaswue { | ||||
|         outer: while (true) { | ||||
|             existingResult = results[tableIndex]; | ||||
|             if (existingResult == null) { | ||||
|                 existingResult = newEntry(results, nameAddress, tableIndex, nameLength, scanner); | ||||
|                 existingResult = newEntry(results, nameAddress, tableIndex, nameLength, scanner, collectedResults); | ||||
|             } | ||||
|             // Check for collision. | ||||
|             int i = 0; | ||||
|             long[] name = existingResult.name; | ||||
|             for (; i < nameLength + 1 - 8; i += 8) { | ||||
|                 if (scanner.getLongAt(i, name) != scanner.getLongAt(nameAddress + i)) { | ||||
|                 if (scanner.getLongAt(existingResult.nameAddress + i) != scanner.getLongAt(nameAddress + i)) { | ||||
|                     // Collision error, try next. | ||||
|                     tableIndex = (tableIndex + 31) & (results.length - 1); | ||||
|                     continue outer; | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             int remainingShift = (64 - (nameLength + 1 - i) << 3); | ||||
|             if (((existingResult.lastNameLong ^ (scanner.getLongAt(nameAddress + i) << remainingShift)) == 0)) { | ||||
|             if (existingResult.lastNameLong == (scanner.getLongAt(nameAddress + i) << remainingShift)) { | ||||
|                 break; | ||||
|             } | ||||
|             else { | ||||
| @@ -258,7 +201,7 @@ public class CalculateAverage_thomaswue { | ||||
|         return existingResult; | ||||
|     } | ||||
|  | ||||
|     private static long nextNL(long prev) { | ||||
|     private static long nextNewLine(long prev) { | ||||
|         while (true) { | ||||
|             long currentWord = Scanner.UNSAFE.getLong(prev); | ||||
|             long pos = findNewLine(currentWord); | ||||
| @@ -273,11 +216,9 @@ public class CalculateAverage_thomaswue { | ||||
|         return prev; | ||||
|     } | ||||
|  | ||||
|     private static final int SEGMENT_SIZE = 1024 * 1024 * 2; | ||||
|  | ||||
|     // Main parse loop. | ||||
|     private static Result[] parseLoop(AtomicLong counter, long fileEnd, long fileStart) { | ||||
|         Result[] results = new Result[1 << 17]; | ||||
|     private static Result[] parseLoop(AtomicLong counter, long fileEnd, long fileStart, List<Result> collectedResults) { | ||||
|         Result[] results = new Result[HASH_TABLE_SIZE]; | ||||
|  | ||||
|         while (true) { | ||||
|             long current = counter.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE; | ||||
| @@ -286,18 +227,18 @@ public class CalculateAverage_thomaswue { | ||||
|                 return results; | ||||
|             } | ||||
|  | ||||
|             long segmentEnd = nextNL(Math.min(fileEnd - 1, current + SEGMENT_SIZE)); | ||||
|             long segmentEnd = nextNewLine(Math.min(fileEnd - 1, current + SEGMENT_SIZE)); | ||||
|             long segmentStart; | ||||
|             if (current == fileStart) { | ||||
|                 segmentStart = current; | ||||
|             } | ||||
|             else { | ||||
|                 segmentStart = nextNL(current) + 1; | ||||
|                 segmentStart = nextNewLine(current) + 1; | ||||
|             } | ||||
|  | ||||
|             long dist = (segmentEnd - segmentStart) / 3; | ||||
|             long midPoint1 = nextNL(segmentStart + dist); | ||||
|             long midPoint2 = nextNL(segmentStart + dist + dist); | ||||
|             long midPoint1 = nextNewLine(segmentStart + dist); | ||||
|             long midPoint2 = nextNewLine(segmentStart + dist + dist); | ||||
|  | ||||
|             Scanner scanner1 = new Scanner(segmentStart, midPoint1); | ||||
|             Scanner scanner2 = new Scanner(midPoint1 + 1, midPoint2); | ||||
| @@ -319,9 +260,9 @@ public class CalculateAverage_thomaswue { | ||||
|                 long pos1 = findDelimiter(word1); | ||||
|                 long pos2 = findDelimiter(word2); | ||||
|                 long pos3 = findDelimiter(word3); | ||||
|                 Result existingResult1 = findResult(word1, pos1, scanner1, results); | ||||
|                 Result existingResult2 = findResult(word2, pos2, scanner2, results); | ||||
|                 Result existingResult3 = findResult(word3, pos3, scanner3, results); | ||||
|                 Result existingResult1 = findResult(word1, pos1, scanner1, results, collectedResults); | ||||
|                 Result existingResult2 = findResult(word2, pos2, scanner2, results, collectedResults); | ||||
|                 Result existingResult3 = findResult(word3, pos3, scanner3, results, collectedResults); | ||||
|                 long number1 = scanNumber(scanner1); | ||||
|                 long number2 = scanNumber(scanner2); | ||||
|                 long number3 = scanNumber(scanner3); | ||||
| @@ -333,19 +274,19 @@ public class CalculateAverage_thomaswue { | ||||
|             while (scanner1.hasNext()) { | ||||
|                 long word = scanner1.getLong(); | ||||
|                 long pos = findDelimiter(word); | ||||
|                 record(findResult(word, pos, scanner1, results), scanNumber(scanner1)); | ||||
|                 record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1)); | ||||
|             } | ||||
|  | ||||
|             while (scanner2.hasNext()) { | ||||
|                 long word = scanner2.getLong(); | ||||
|                 long pos = findDelimiter(word); | ||||
|                 record(findResult(word, pos, scanner2, results), scanNumber(scanner2)); | ||||
|                 record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2)); | ||||
|             } | ||||
|  | ||||
|             while (scanner3.hasNext()) { | ||||
|                 long word = scanner3.getLong(); | ||||
|                 long pos = findDelimiter(word); | ||||
|                 record(findResult(word, pos, scanner3, results), scanNumber(scanner3)); | ||||
|                 record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3)); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| @@ -361,10 +302,10 @@ public class CalculateAverage_thomaswue { | ||||
|  | ||||
|     private static void record(Result existingResult, long number) { | ||||
|         if (number < existingResult.min) { | ||||
|             existingResult.min = number; | ||||
|             existingResult.min = (short) number; | ||||
|         } | ||||
|         if (number > existingResult.max) { | ||||
|             existingResult.max = number; | ||||
|             existingResult.max = (short) number; | ||||
|         } | ||||
|         existingResult.sum += number; | ||||
|         existingResult.count++; | ||||
| @@ -406,31 +347,71 @@ public class CalculateAverage_thomaswue { | ||||
|         return tmp; | ||||
|     } | ||||
|  | ||||
|     private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner) { | ||||
|     private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List<Result> collectedResults) { | ||||
|         Result r = new Result(); | ||||
|         results[hash] = r; | ||||
|         long[] name = new long[(nameLength / Long.BYTES) + 1]; | ||||
|         int pos = 0; | ||||
|         int i = 0; | ||||
|         for (; i < nameLength + 1 - Long.BYTES; i += Long.BYTES) { | ||||
|             name[pos++] = scanner.getLongAt(nameAddress + i); | ||||
|         } | ||||
|  | ||||
|         if (pos > 0) { | ||||
|             r.secondLastNameLong = name[pos - 1]; | ||||
|         if (nameLength + 1 > 8) { | ||||
|             r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8); | ||||
|         } | ||||
|  | ||||
|         int remainingShift = (64 - (nameLength + 1 - i) << 3); | ||||
|         long lastWord = (scanner.getLongAt(nameAddress + i) << remainingShift); | ||||
|         r.lastNameLong = lastWord; | ||||
|         name[pos] = lastWord >> remainingShift; | ||||
|         r.name = name; | ||||
|         r.nameAddress = nameAddress; | ||||
|         collectedResults.add(r); | ||||
|         return r; | ||||
|     } | ||||
|  | ||||
|     private static class Scanner { | ||||
|     private static class Result { | ||||
|         long lastNameLong, secondLastNameLong; | ||||
|         short min, max; | ||||
|         int count; | ||||
|         long sum; | ||||
|         long nameAddress; | ||||
|  | ||||
|         private Result() { | ||||
|             this.min = MAX_TEMP; | ||||
|             this.max = MIN_TEMP; | ||||
|         } | ||||
|  | ||||
|         public String toString() { | ||||
|             return round(((double) min) / 10.0) + "/" + round((((double) sum) / 10.0) / count) + "/" + round(((double) max) / 10.0); | ||||
|         } | ||||
|  | ||||
|         private static double round(double value) { | ||||
|             return Math.round(value * 10.0) / 10.0; | ||||
|         } | ||||
|  | ||||
|         private void accumulate(Result other) { | ||||
|             if (other.min < min) { | ||||
|                 min = other.min; | ||||
|             } | ||||
|             if (other.max > max) { | ||||
|                 max = other.max; | ||||
|             } | ||||
|             sum += other.sum; | ||||
|             count += other.count; | ||||
|         } | ||||
|  | ||||
|         public String calcName() { | ||||
|             Scanner scanner = new Scanner(nameAddress, nameAddress + MAX_NAME_LENGTH + 1); | ||||
|             int nameLength = 0; | ||||
|             while (scanner.getByteAt(nameAddress + nameLength) != ';') { | ||||
|                 nameLength++; | ||||
|             } | ||||
|             byte[] array = new byte[nameLength]; | ||||
|             for (int i = 0; i < nameLength; ++i) { | ||||
|                 array[i] = scanner.getByteAt(nameAddress + i); | ||||
|             } | ||||
|             return new String(array, java.nio.charset.StandardCharsets.UTF_8); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static class Scanner { | ||||
|         private static final sun.misc.Unsafe UNSAFE = initUnsafe(); | ||||
|         private long pos, end; | ||||
|  | ||||
|         private static sun.misc.Unsafe initUnsafe() { | ||||
|             try { | ||||
| @@ -443,8 +424,6 @@ public class CalculateAverage_thomaswue { | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         long pos, end; | ||||
|  | ||||
|         public Scanner(long start, long end) { | ||||
|             this.pos = start; | ||||
|             this.end = end; | ||||
| @@ -470,6 +449,10 @@ public class CalculateAverage_thomaswue { | ||||
|             return UNSAFE.getLong(pos); | ||||
|         } | ||||
|  | ||||
|         byte getByteAt(long pos) { | ||||
|             return UNSAFE.getByte(pos); | ||||
|         } | ||||
|  | ||||
|         long getLongAt(long pos, long[] array) { | ||||
|             return UNSAFE.getLong(array, pos + sun.misc.Unsafe.ARRAY_LONG_BASE_OFFSET); | ||||
|         } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user