diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_vemana.java b/src/main/java/dev/morling/onebrc/CalculateAverage_vemana.java index 8f690e3..3e64ac9 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_vemana.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_vemana.java @@ -171,7 +171,7 @@ public class CalculateAverage_vemana { int chunkSizeBits = 20; // For the last commonChunkFraction fraction of total work, use smaller chunk sizes - double commonChunkFraction = 0; + double commonChunkFraction = 0.03; // Use commonChunkSizeBits for the small-chunk size int commonChunkSizeBits = 18; @@ -181,11 +181,17 @@ public class CalculateAverage_vemana { int minReservedBytesAtFileTail = 9; + int nThreads = -1; + String inputFile = "measurements.txt"; + double munmapFraction = 0.03; + + boolean fakeAdvance = false; + for (String arg : args) { - String key = arg.substring(0, arg.indexOf('=')); - String value = arg.substring(key.length() + 1); + String key = arg.substring(0, arg.indexOf('=')).trim(); + String value = arg.substring(key.length() + 1).trim(); switch (key) { case "chunkSizeBits": chunkSizeBits = Integer.parseInt(value); @@ -202,6 +208,15 @@ public class CalculateAverage_vemana { case "inputfile": inputFile = value; break; + case "munmapFraction": + munmapFraction = Double.parseDouble(value); + break; + case "fakeAdvance": + fakeAdvance = Boolean.parseBoolean(value); + break; + case "nThreads": + nThreads = Integer.parseInt(value); + break; default: throw new IllegalArgumentException("Unknown argument: " + arg); } @@ -218,14 +233,17 @@ public class CalculateAverage_vemana { System.out.println( new Runner( Path.of(inputFile), + nThreads, chunkSizeBits, commonChunkFraction, commonChunkSizeBits, hashtableSizeBits, - minReservedBytesAtFileTail) + minReservedBytesAtFileTail, + munmapFraction, + fakeAdvance) .getSummaryStatistics()); - Tracing.recordEvent("After printing result"); + Tracing.recordEvent("Final result printed"); } public record AggregateResult(Map tempStats) { @@ -286,8 +304,8 @@ public class CalculateAverage_vemana { bufferEnd = bufferStart = -1; } - public void close(int shardIdx) { - Tracing.recordWorkStart("cleaner", shardIdx); + public void close(String closerId, int shardIdx) { + Tracing.recordWorkStart(closerId, shardIdx); if (byteBuffer != null) { unclosedBuffers.add(byteBuffer); } @@ -297,7 +315,7 @@ public class CalculateAverage_vemana { unclosedBuffers.clear(); bufferEnd = bufferStart = -1; byteBuffer = null; - Tracing.recordWorkEnd("cleaner", shardIdx); + Tracing.recordWorkEnd(closerId, shardIdx); } public void setRange(long rangeStart, long rangeEnd) { @@ -383,7 +401,7 @@ public class CalculateAverage_vemana { public interface LazyShardQueue { - void close(int shardIdx); + void close(String closerId, int shardIdx); Optional fileTailEndWork(int idx); @@ -415,37 +433,48 @@ public class CalculateAverage_vemana { private final double commonChunkFraction; private final int commonChunkSizeBits; + private final boolean fakeAdvance; private final int hashtableSizeBits; private final Path inputFile; private final int minReservedBytesAtFileTail; + private final double munmapFraction; + private final int nThreads; private final int shardSizeBits; public Runner( Path inputFile, + int nThreads, int chunkSizeBits, double commonChunkFraction, int commonChunkSizeBits, int hashtableSizeBits, - int minReservedBytesAtFileTail) { + int minReservedBytesAtFileTail, + double munmapFraction, + boolean fakeAdvance) { this.inputFile = inputFile; + this.nThreads = nThreads; this.shardSizeBits = chunkSizeBits; this.commonChunkFraction = commonChunkFraction; this.commonChunkSizeBits = commonChunkSizeBits; this.hashtableSizeBits = hashtableSizeBits; this.minReservedBytesAtFileTail = minReservedBytesAtFileTail; + this.munmapFraction = munmapFraction; + this.fakeAdvance = fakeAdvance; } AggregateResult getSummaryStatistics() throws Exception { - int nThreads = Runtime.getRuntime().availableProcessors(); + int nThreads = this.nThreads < 0 ? Runtime.getRuntime().availableProcessors() : this.nThreads; + LazyShardQueue shardQueue = new SerialLazyShardQueue( 1L << shardSizeBits, inputFile, nThreads, commonChunkFraction, commonChunkSizeBits, - minReservedBytesAtFileTail); + minReservedBytesAtFileTail, + munmapFraction, + fakeAdvance); - List> results = new ArrayList<>(); ExecutorService executorService = Executors.newFixedThreadPool( nThreads, runnable -> { @@ -454,42 +483,56 @@ public class CalculateAverage_vemana { return thread; }); + List> results = new ArrayList<>(); for (int i = 0; i < nThreads; i++) { final int shardIdx = i; final Callable callable = () -> { - Tracing.recordWorkStart("shard", shardIdx); + Tracing.recordWorkStart("Shard", shardIdx); AggregateResult result = new ShardProcessor(shardQueue, hashtableSizeBits, shardIdx).processShard(); - Tracing.recordWorkEnd("shard", shardIdx); + Tracing.recordWorkEnd("Shard", shardIdx); return result; }; results.add(executorService.submit(callable)); } Tracing.recordEvent("Basic push time"); - AggregateResult result = executorService.submit(() -> merge(results)).get(); + // This particular sequence of Futures is so that both merge and munmap() can work as shards + // finish their computation without blocking on the entire set of shards to complete. In + // particular, munmap() doesn't need to wait on merge. + // First, submit a task to merge the results and then submit a task to cleanup bytebuffers + // from completed shards. + Future resultFutures = executorService.submit(() -> merge(results)); + // Note that munmap() is serial and not parallel and hence we use just one thread. + executorService.submit(() -> closeByteBuffers(results, shardQueue)); + AggregateResult result = resultFutures.get(); Tracing.recordEvent("Merge results received"); - // Note that munmap() is serial and not parallel - executorService.submit( - () -> { - for (int i = 0; i < nThreads; i++) { - shardQueue.close(i); - } - }); - - Tracing.recordEvent("Waiting for executor shutdown"); - + Tracing.recordEvent("About to shutdown executor and wait"); executorService.shutdown(); executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS); - Tracing.recordEvent("Executor terminated"); - Tracing.analyzeWorkThreads("cleaner", nThreads); - Tracing.recordEvent("After cleaner finish printed"); + Tracing.analyzeWorkThreads(nThreads); return result; } + private void closeByteBuffers( + List> results, LazyShardQueue shardQueue) { + int n = results.size(); + boolean[] isDone = new boolean[n]; + int remaining = results.size(); + while (remaining > 0) { + for (int i = 0; i < n; i++) { + if (!isDone[i] && results.get(i).isDone()) { + remaining--; + isDone[i] = true; + shardQueue.close("Ending Cleaner", i); + } + } + } + } + private AggregateResult merge(List> results) throws ExecutionException, InterruptedException { Tracing.recordEvent("Merge start time"); @@ -516,7 +559,6 @@ public class CalculateAverage_vemana { } } Tracing.recordEvent("Merge end time"); - Tracing.analyzeWorkThreads("shard", results.size()); return new AggregateResult(output); } } @@ -532,6 +574,7 @@ public class CalculateAverage_vemana { private final long commonChunkSize; private final AtomicLong commonPool; private final long effectiveFileSize; + private final boolean fakeAdvance; private final long fileSize; private final long[] perThreadData; private final RandomAccessFile raf; @@ -543,8 +586,11 @@ public class CalculateAverage_vemana { int shards, double commonChunkFraction, int commonChunkSizeBits, - int fileTailReservedBytes) + int fileTailReservedBytes, + double munmapFraction, + boolean fakeAdvance) throws IOException { + this.fakeAdvance = fakeAdvance; Checks.checkArg(commonChunkFraction < 0.9 && commonChunkFraction >= 0); Checks.checkArg(fileTailReservedBytes >= 0); this.raf = new RandomAccessFile(filePath.toFile(), "r"); @@ -580,8 +626,8 @@ public class CalculateAverage_vemana { // its work, where R = relative speed of unmap compared to the computation. // For our problem, R ~ 75 because unmap unmaps 30GB/sec (but, it is serial) while // cores go through data at the rate of 400MB/sec. - perThreadData[pos + 3] = (long) (currentChunks * (0.03 * (shards - i))); - perThreadData[pos + 4] = 1; + perThreadData[pos + 3] = (long) (currentChunks * (munmapFraction * (shards - i))); + perThreadData[pos + 4] = 1; // true iff munmap() hasn't been triggered yet currentStart += currentChunks * chunkSize; remainingChunks -= currentChunks; } @@ -596,8 +642,8 @@ public class CalculateAverage_vemana { } @Override - public void close(int shardIdx) { - byteRanges[shardIdx << 4].close(shardIdx); + public void close(String closerId, int shardIdx) { + byteRanges[shardIdx << 4].close(closerId, shardIdx); } @Override @@ -616,14 +662,18 @@ public class CalculateAverage_vemana { public ByteRange take(int shardIdx) { // Try for thread local range final int pos = shardIdx << 4; - long rangeStart = perThreadData[pos]; - final long chunkEnd = perThreadData[pos + 1]; + final long rangeStart; final long rangeEnd; - if (rangeStart < chunkEnd) { + if (perThreadData[pos + 2] >= 1) { + rangeStart = perThreadData[pos]; rangeEnd = rangeStart + chunkSize; - perThreadData[pos] = rangeEnd; + // Don't do this in the if-check; it causes negative values that trigger intermediate + // cleanup perThreadData[pos + 2]--; + if (!fakeAdvance) { + perThreadData[pos] = rangeEnd; + } } else { rangeStart = commonPool.getAndAdd(commonChunkSize); @@ -634,8 +684,8 @@ public class CalculateAverage_vemana { rangeEnd = rangeStart + commonChunkSize; } - if (perThreadData[pos + 2] <= perThreadData[pos + 3] && perThreadData[pos + 4] > 0) { - if (attemptClose(shardIdx)) { + if (perThreadData[pos + 2] < perThreadData[pos + 3] && perThreadData[pos + 4] > 0) { + if (attemptIntermediateClose(shardIdx)) { perThreadData[pos + 4]--; } } @@ -645,9 +695,9 @@ public class CalculateAverage_vemana { return chunk; } - private boolean attemptClose(int shardIdx) { + private boolean attemptIntermediateClose(int shardIdx) { if (seqLock.acquire()) { - byteRanges[shardIdx << 4].close(shardIdx); + close("Intermediate Cleaner", shardIdx); seqLock.release(); return true; } @@ -964,12 +1014,22 @@ public class CalculateAverage_vemana { static class Tracing { - private static final long[] cleanerTimes = new long[1 << 6 << 1]; - private static final long[] threadTimes = new long[1 << 6 << 1]; + private static final Map knownWorkThreadEvents; private static long startTime; - static void analyzeWorkThreads(String id, int nThreads) { - printTimingsAnalysis(id + " Stats", nThreads, timingsArray(id)); + static { + // Maintain the ordering to be chronological in execution + // Map.of(..) screws up ordering + knownWorkThreadEvents = new LinkedHashMap<>(); + for (String id : List.of("Shard", "Intermediate Cleaner", "Ending Cleaner")) { + knownWorkThreadEvents.put(id, new ThreadTimingsArray(id, 1 << 6 << 1)); + } + } + + static void analyzeWorkThreads(int nThreads) { + for (ThreadTimingsArray array : knownWorkThreadEvents.values()) { + errPrint(array.analyze(nThreads)); + } } static void recordAppStart() { @@ -981,11 +1041,11 @@ public class CalculateAverage_vemana { } static void recordWorkEnd(String id, int threadId) { - timingsArray(id)[2 * threadId + 1] = System.nanoTime(); + knownWorkThreadEvents.get(id).recordEnd(threadId); } static void recordWorkStart(String id, int threadId) { - timingsArray(id)[2 * threadId] = System.nanoTime(); + knownWorkThreadEvents.get(id).recordStart(threadId); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -998,57 +1058,78 @@ public class CalculateAverage_vemana { errPrint(STR."\{message} = \{(nanoTime - startTime) / 1_000_000}ms"); } - private static void printTimingsAnalysis(String header, int nThreads, long[] timestamps) { - long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE; - long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE; - long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE; + public static class ThreadTimingsArray { - long[] durationsMs = new long[nThreads]; - long[] completionsMs = new long[nThreads]; - long[] beginMs = new long[nThreads]; - for (int i = 0; i < nThreads; i++) { - long durationNs = timestamps[2 * i + 1] - timestamps[2 * i]; - durationsMs[i] = durationNs / 1_000_000; - completionsMs[i] = (timestamps[2 * i + 1] - startTime) / 1_000_000; - beginMs[i] = (timestamps[2 * i] - startTime) / 1_000_000; + private static String toString(long[] array) { + return Arrays.stream(array) + .map(x -> x < 0 ? -1 : x) + .mapToObj(x -> String.format("%6d", x)) + .collect(Collectors.joining(", ", "[ ", " ]")); + } - minDuration = Math.min(minDuration, durationNs); - maxDuration = Math.max(maxDuration, durationNs); + private final String id; + private final long[] timestamps; + private boolean hasData = false; - minBegin = Math.min(minBegin, timestamps[2 * i]); - maxBegin = Math.max(maxBegin, timestamps[2 * i]); + public ThreadTimingsArray(String id, int maxSize) { + this.timestamps = new long[maxSize]; + this.id = id; + } - maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1]); - minCompletion = Math.min(minCompletion, timestamps[2 * i + 1]); - } - errPrint( - STR.""" + public String analyze(int nThreads) { + if (!hasData) { + return "%s has no thread timings data".formatted(id); + } + Checks.checkArg(nThreads <= timestamps.length); + long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE; + long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE; + long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE; + + long[] durationsMs = new long[nThreads]; + long[] completionsMs = new long[nThreads]; + long[] beginMs = new long[nThreads]; + for (int i = 0; i < nThreads; i++) { + long durationNs = timestamps[2 * i + 1] - timestamps[2 * i]; + durationsMs[i] = durationNs / 1_000_000; + completionsMs[i] = (timestamps[2 * i + 1] - startTime) / 1_000_000; + beginMs[i] = (timestamps[2 * i] - startTime) / 1_000_000; + + minDuration = Math.min(minDuration, durationNs); + maxDuration = Math.max(maxDuration, durationNs); + + minBegin = Math.min(minBegin, timestamps[2 * i] - startTime); + maxBegin = Math.max(maxBegin, timestamps[2 * i] - startTime); + + maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1] - startTime); + minCompletion = Math.min(minCompletion, timestamps[2 * i + 1] - startTime); + } + return STR.""" ------------------------------------------------------------------------------------------- - \{header} + \{id} Stats ------------------------------------------------------------------------------------------- Max duration = \{maxDuration / 1_000_000} ms Min duration = \{minDuration / 1_000_000} ms - Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms + Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms [\{maxCompletion / 1_000_000} - \{minBegin / 1_000_000} ] Completion Timespan[max(end)-min(end)] = \{(maxCompletion - minCompletion) / 1_000_000} ms Begin Timespan[max(begin)-min(begin)] = \{(maxBegin - minBegin) / 1_000_000} ms - Durations = \{toString(durationsMs)} in ms - Begin Timestamps = \{toString(beginMs)} in ms - Completion Timestamps = \{toString(completionsMs)} in ms - """); - } + Average Duration = \{Arrays.stream(durationsMs) + .average() + .getAsDouble()} ms + Durations = \{toString(durationsMs)} ms + Begin Timestamps = \{toString(beginMs)} ms + Completion Timestamps = \{toString(completionsMs)} ms + """; + } - private static long[] timingsArray(String id) { - return switch (id) { - case "cleaner" -> cleanerTimes; - case "shard" -> threadTimes; - default -> throw new RuntimeException(""); - }; - } + public void recordEnd(int idx) { + timestamps[2 * idx + 1] = System.nanoTime(); + hasData = true; + } - private static String toString(long[] array) { - return Arrays.stream(array) - .mapToObj(x -> String.format("%6d", x)) - .collect(Collectors.joining(", ", "[ ", " ]")); + public void recordStart(int idx) { + timestamps[2 * idx] = System.nanoTime(); + hasData = true; + } } } }