Reduce variance by (1) using common chunks at the end (2) busy looping (#486)

Busy looping applies to the automatic closing of ByteBuffers: previously,
a straggler could hold up closing the ByteBuffers.

Also:
- Improve the Tracing code
- Parametrize additional options to aid tuning

Our previous PR was surprising: parallelizing the munmap() calls did not
yield anywhere near the performance gain I expected. My local machine
showed a 10% gain while the test machine showed only 2%. I am still not
sure why; the two best theories I have are
1) Variance due to stragglers (which this change addresses)
2) On the test machine, munmap() is either too fast or too slow relative
   to the other instructions compared to my local machine. I don't know
   which. We'll have to use adaptive tuning, but that's for a different
   change.
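
For readers skimming the diff, here is a minimal, self-contained sketch of the busy-looping pattern in point (2). All names and timings below are illustrative only; the actual implementation is the closeByteBuffers method in the diff further down. A single extra task polls the shard futures and frees each shard's resources the moment that shard finishes, so one straggler no longer delays cleanup for everyone else:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;
    import java.util.concurrent.TimeUnit;

    public class BusyLoopCloserSketch {
        public static void main(String[] args) throws Exception {
            int nShards = 8;
            ExecutorService pool = Executors.newFixedThreadPool(nShards + 1);
            List<Future<Integer>> results = new ArrayList<>();
            for (int i = 0; i < nShards; i++) {
                final int shardIdx = i;
                results.add(pool.submit(() -> {
                    // Simulate uneven shard durations; shard 0 is the straggler.
                    Thread.sleep(shardIdx == 0 ? 2000 : 100);
                    return shardIdx;
                }));
            }
            // Busy-loop closer: never blocks on a particular future, so shards
            // 1..7 get "closed" long before the straggler (shard 0) finishes.
            pool.submit(() -> {
                boolean[] closed = new boolean[nShards];
                int remaining = nShards;
                while (remaining > 0) {
                    for (int i = 0; i < nShards; i++) {
                        if (!closed[i] && results.get(i).isDone()) {
                            closed[i] = true;
                            remaining--;
                            System.out.println("closing buffers of shard " + i);
                        }
                    }
                }
            });
            pool.shutdown();
            pool.awaitTermination(1, TimeUnit.MINUTES);
        }
    }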
Vemana 2024-01-20 02:17:55 +05:30 committed by GitHub
parent 144a6af164
commit 6e3893c6a6

@@ -171,7 +171,7 @@ public class CalculateAverage_vemana {
         int chunkSizeBits = 20;
         // For the last commonChunkFraction fraction of total work, use smaller chunk sizes
-        double commonChunkFraction = 0;
+        double commonChunkFraction = 0.03;
         // Use commonChunkSizeBits for the small-chunk size
         int commonChunkSizeBits = 18;
@@ -181,11 +181,17 @@ public class CalculateAverage_vemana {
         int minReservedBytesAtFileTail = 9;
+        int nThreads = -1;
         String inputFile = "measurements.txt";
+        double munmapFraction = 0.03;
+        boolean fakeAdvance = false;
         for (String arg : args) {
-            String key = arg.substring(0, arg.indexOf('='));
-            String value = arg.substring(key.length() + 1);
+            String key = arg.substring(0, arg.indexOf('=')).trim();
+            String value = arg.substring(key.length() + 1).trim();
             switch (key) {
                 case "chunkSizeBits":
                     chunkSizeBits = Integer.parseInt(value);
@@ -202,6 +208,15 @@ public class CalculateAverage_vemana {
                 case "inputfile":
                     inputFile = value;
                     break;
+                case "munmapFraction":
+                    munmapFraction = Double.parseDouble(value);
+                    break;
+                case "fakeAdvance":
+                    fakeAdvance = Boolean.parseBoolean(value);
+                    break;
+                case "nThreads":
+                    nThreads = Integer.parseInt(value);
+                    break;
                 default:
                     throw new IllegalArgumentException("Unknown argument: " + arg);
             }
@@ -218,14 +233,17 @@ public class CalculateAverage_vemana {
         System.out.println(
                 new Runner(
                         Path.of(inputFile),
+                        nThreads,
                         chunkSizeBits,
                         commonChunkFraction,
                         commonChunkSizeBits,
                         hashtableSizeBits,
-                        minReservedBytesAtFileTail)
+                        minReservedBytesAtFileTail,
+                        munmapFraction,
+                        fakeAdvance)
                         .getSummaryStatistics());
-        Tracing.recordEvent("After printing result");
+        Tracing.recordEvent("Final result printed");
     }

     public record AggregateResult(Map<String, Stat> tempStats) {
@@ -286,8 +304,8 @@ public class CalculateAverage_vemana {
             bufferEnd = bufferStart = -1;
         }

-        public void close(int shardIdx) {
-            Tracing.recordWorkStart("cleaner", shardIdx);
+        public void close(String closerId, int shardIdx) {
+            Tracing.recordWorkStart(closerId, shardIdx);
             if (byteBuffer != null) {
                 unclosedBuffers.add(byteBuffer);
             }
@@ -297,7 +315,7 @@ public class CalculateAverage_vemana {
             unclosedBuffers.clear();
             bufferEnd = bufferStart = -1;
             byteBuffer = null;
-            Tracing.recordWorkEnd("cleaner", shardIdx);
+            Tracing.recordWorkEnd(closerId, shardIdx);
         }

         public void setRange(long rangeStart, long rangeEnd) {
@@ -383,7 +401,7 @@ public class CalculateAverage_vemana {
     public interface LazyShardQueue {

-        void close(int shardIdx);
+        void close(String closerId, int shardIdx);

         Optional<ByteRange> fileTailEndWork(int idx);
@@ -415,37 +433,48 @@ public class CalculateAverage_vemana {
         private final double commonChunkFraction;
         private final int commonChunkSizeBits;
+        private final boolean fakeAdvance;
         private final int hashtableSizeBits;
         private final Path inputFile;
         private final int minReservedBytesAtFileTail;
+        private final double munmapFraction;
+        private final int nThreads;
         private final int shardSizeBits;

         public Runner(
                 Path inputFile,
+                int nThreads,
                 int chunkSizeBits,
                 double commonChunkFraction,
                 int commonChunkSizeBits,
                 int hashtableSizeBits,
-                int minReservedBytesAtFileTail) {
+                int minReservedBytesAtFileTail,
+                double munmapFraction,
+                boolean fakeAdvance) {
             this.inputFile = inputFile;
+            this.nThreads = nThreads;
             this.shardSizeBits = chunkSizeBits;
             this.commonChunkFraction = commonChunkFraction;
             this.commonChunkSizeBits = commonChunkSizeBits;
             this.hashtableSizeBits = hashtableSizeBits;
             this.minReservedBytesAtFileTail = minReservedBytesAtFileTail;
+            this.munmapFraction = munmapFraction;
+            this.fakeAdvance = fakeAdvance;
         }

         AggregateResult getSummaryStatistics() throws Exception {
-            int nThreads = Runtime.getRuntime().availableProcessors();
+            int nThreads = this.nThreads < 0 ? Runtime.getRuntime().availableProcessors() : this.nThreads;
             LazyShardQueue shardQueue = new SerialLazyShardQueue(
                     1L << shardSizeBits,
                     inputFile,
                     nThreads,
                     commonChunkFraction,
                     commonChunkSizeBits,
-                    minReservedBytesAtFileTail);
+                    minReservedBytesAtFileTail,
+                    munmapFraction,
+                    fakeAdvance);
+
+            List<Future<AggregateResult>> results = new ArrayList<>();
             ExecutorService executorService = Executors.newFixedThreadPool(
                     nThreads,
                     runnable -> {
@@ -454,42 +483,56 @@ public class CalculateAverage_vemana {
                         return thread;
                     });
-            List<Future<AggregateResult>> results = new ArrayList<>();
             for (int i = 0; i < nThreads; i++) {
                 final int shardIdx = i;
                 final Callable<AggregateResult> callable = () -> {
-                    Tracing.recordWorkStart("shard", shardIdx);
+                    Tracing.recordWorkStart("Shard", shardIdx);
                     AggregateResult result = new ShardProcessor(shardQueue, hashtableSizeBits, shardIdx).processShard();
-                    Tracing.recordWorkEnd("shard", shardIdx);
+                    Tracing.recordWorkEnd("Shard", shardIdx);
                     return result;
                 };
                 results.add(executorService.submit(callable));
             }
             Tracing.recordEvent("Basic push time");
-            AggregateResult result = executorService.submit(() -> merge(results)).get();
+            // This particular sequence of Futures is so that both merge and munmap() can work as shards
+            // finish their computation without blocking on the entire set of shards to complete. In
+            // particular, munmap() doesn't need to wait on merge.
+            // First, submit a task to merge the results and then submit a task to cleanup bytebuffers
+            // from completed shards.
+            Future<AggregateResult> resultFutures = executorService.submit(() -> merge(results));
+            // Note that munmap() is serial and not parallel and hence we use just one thread.
+            executorService.submit(() -> closeByteBuffers(results, shardQueue));
+            AggregateResult result = resultFutures.get();
             Tracing.recordEvent("Merge results received");
-            // Note that munmap() is serial and not parallel
-            executorService.submit(
-                    () -> {
-                        for (int i = 0; i < nThreads; i++) {
-                            shardQueue.close(i);
-                        }
-                    });
-            Tracing.recordEvent("Waiting for executor shutdown");
+            Tracing.recordEvent("About to shutdown executor and wait");
             executorService.shutdown();
             executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS);
             Tracing.recordEvent("Executor terminated");
-            Tracing.analyzeWorkThreads("cleaner", nThreads);
-            Tracing.recordEvent("After cleaner finish printed");
+            Tracing.analyzeWorkThreads(nThreads);
             return result;
         }

+        private void closeByteBuffers(
+                List<Future<AggregateResult>> results, LazyShardQueue shardQueue) {
+            int n = results.size();
+            boolean[] isDone = new boolean[n];
+            int remaining = results.size();
+            while (remaining > 0) {
+                for (int i = 0; i < n; i++) {
+                    if (!isDone[i] && results.get(i).isDone()) {
+                        remaining--;
+                        isDone[i] = true;
+                        shardQueue.close("Ending Cleaner", i);
+                    }
+                }
+            }
+        }
+
         private AggregateResult merge(List<Future<AggregateResult>> results)
                 throws ExecutionException, InterruptedException {
             Tracing.recordEvent("Merge start time");
@@ -516,7 +559,6 @@ public class CalculateAverage_vemana {
                 }
             }
             Tracing.recordEvent("Merge end time");
-            Tracing.analyzeWorkThreads("shard", results.size());
             return new AggregateResult(output);
         }
     }
@@ -532,6 +574,7 @@ public class CalculateAverage_vemana {
         private final long commonChunkSize;
         private final AtomicLong commonPool;
         private final long effectiveFileSize;
+        private final boolean fakeAdvance;
         private final long fileSize;
         private final long[] perThreadData;
         private final RandomAccessFile raf;
@@ -543,8 +586,11 @@ public class CalculateAverage_vemana {
                 int shards,
                 double commonChunkFraction,
                 int commonChunkSizeBits,
-                int fileTailReservedBytes)
+                int fileTailReservedBytes,
+                double munmapFraction,
+                boolean fakeAdvance)
                 throws IOException {
+            this.fakeAdvance = fakeAdvance;
             Checks.checkArg(commonChunkFraction < 0.9 && commonChunkFraction >= 0);
             Checks.checkArg(fileTailReservedBytes >= 0);
             this.raf = new RandomAccessFile(filePath.toFile(), "r");
@@ -580,8 +626,8 @@ public class CalculateAverage_vemana {
                 // its work, where R = relative speed of unmap compared to the computation.
                 // For our problem, R ~ 75 because unmap unmaps 30GB/sec (but, it is serial) while
                 // cores go through data at the rate of 400MB/sec.
-                perThreadData[pos + 3] = (long) (currentChunks * (0.03 * (shards - i)));
-                perThreadData[pos + 4] = 1;
+                perThreadData[pos + 3] = (long) (currentChunks * (munmapFraction * (shards - i)));
+                perThreadData[pos + 4] = 1; // true iff munmap() hasn't been triggered yet
                 currentStart += currentChunks * chunkSize;
                 remainingChunks -= currentChunks;
             }
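
[Aside, sanity-checking the R ~ 75 comment above: 30 GB/s of unmap against 0.4 GB/s of compute gives R = 30 / 0.4 = 75, so a thread roughly breaks even by triggering its serial unmap when about 1/(R+1) ≈ 1.3% of its work remains. The (shards - i) factor staggers each thread's trigger threshold so the serial unmaps don't all contend for the seqLock at once; with the default munmapFraction = 0.03, thread i starts its intermediate close once fewer than 0.03 * (shards - i) of its chunks remain.]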
@@ -596,8 +642,8 @@ public class CalculateAverage_vemana {
         }

         @Override
-        public void close(int shardIdx) {
-            byteRanges[shardIdx << 4].close(shardIdx);
+        public void close(String closerId, int shardIdx) {
+            byteRanges[shardIdx << 4].close(closerId, shardIdx);
         }

         @Override
@@ -616,14 +662,18 @@ public class CalculateAverage_vemana {
         public ByteRange take(int shardIdx) {
             // Try for thread local range
             final int pos = shardIdx << 4;
-            long rangeStart = perThreadData[pos];
-            final long chunkEnd = perThreadData[pos + 1];
+            final long rangeStart;
             final long rangeEnd;
-            if (rangeStart < chunkEnd) {
+            if (perThreadData[pos + 2] >= 1) {
+                rangeStart = perThreadData[pos];
                 rangeEnd = rangeStart + chunkSize;
-                perThreadData[pos] = rangeEnd;
+                // Don't do this in the if-check; it causes negative values that trigger intermediate
+                // cleanup
                 perThreadData[pos + 2]--;
+                if (!fakeAdvance) {
+                    perThreadData[pos] = rangeEnd;
+                }
             }
             else {
                 rangeStart = commonPool.getAndAdd(commonChunkSize);
@@ -634,8 +684,8 @@ public class CalculateAverage_vemana {
                 rangeEnd = rangeStart + commonChunkSize;
             }

-            if (perThreadData[pos + 2] <= perThreadData[pos + 3] && perThreadData[pos + 4] > 0) {
-                if (attemptClose(shardIdx)) {
+            if (perThreadData[pos + 2] < perThreadData[pos + 3] && perThreadData[pos + 4] > 0) {
+                if (attemptIntermediateClose(shardIdx)) {
                     perThreadData[pos + 4]--;
                 }
             }
@@ -645,9 +695,9 @@ public class CalculateAverage_vemana {
             return chunk;
         }

-        private boolean attemptClose(int shardIdx) {
+        private boolean attemptIntermediateClose(int shardIdx) {
             if (seqLock.acquire()) {
-                byteRanges[shardIdx << 4].close(shardIdx);
+                close("Intermediate Cleaner", shardIdx);
                 seqLock.release();
                 return true;
             }
@@ -964,12 +1014,22 @@ public class CalculateAverage_vemana {
     static class Tracing {

-        private static final long[] cleanerTimes = new long[1 << 6 << 1];
-        private static final long[] threadTimes = new long[1 << 6 << 1];
+        private static final Map<String, ThreadTimingsArray> knownWorkThreadEvents;
         private static long startTime;

-        static void analyzeWorkThreads(String id, int nThreads) {
-            printTimingsAnalysis(id + " Stats", nThreads, timingsArray(id));
+        static {
+            // Maintain the ordering to be chronological in execution
+            // Map.of(..) screws up ordering
+            knownWorkThreadEvents = new LinkedHashMap<>();
+            for (String id : List.of("Shard", "Intermediate Cleaner", "Ending Cleaner")) {
+                knownWorkThreadEvents.put(id, new ThreadTimingsArray(id, 1 << 6 << 1));
+            }
+        }
+
+        static void analyzeWorkThreads(int nThreads) {
+            for (ThreadTimingsArray array : knownWorkThreadEvents.values()) {
+                errPrint(array.analyze(nThreads));
+            }
         }

         static void recordAppStart() {
@@ -981,11 +1041,11 @@ public class CalculateAverage_vemana {
         }

         static void recordWorkEnd(String id, int threadId) {
-            timingsArray(id)[2 * threadId + 1] = System.nanoTime();
+            knownWorkThreadEvents.get(id).recordEnd(threadId);
         }

         static void recordWorkStart(String id, int threadId) {
-            timingsArray(id)[2 * threadId] = System.nanoTime();
+            knownWorkThreadEvents.get(id).recordStart(threadId);
        }

         /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -998,7 +1058,29 @@ public class CalculateAverage_vemana {
            errPrint(STR."\{message} = \{(nanoTime - startTime) / 1_000_000}ms");
        }

-        private static void printTimingsAnalysis(String header, int nThreads, long[] timestamps) {
+        public static class ThreadTimingsArray {
+
+            private static String toString(long[] array) {
+                return Arrays.stream(array)
+                        .map(x -> x < 0 ? -1 : x)
+                        .mapToObj(x -> String.format("%6d", x))
+                        .collect(Collectors.joining(", ", "[ ", " ]"));
+            }
+
+            private final String id;
+            private final long[] timestamps;
+            private boolean hasData = false;
+
+            public ThreadTimingsArray(String id, int maxSize) {
+                this.timestamps = new long[maxSize];
+                this.id = id;
+            }
+
+            public String analyze(int nThreads) {
+                if (!hasData) {
+                    return "%s has no thread timings data".formatted(id);
+                }
+                Checks.checkArg(nThreads <= timestamps.length);
                long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE;
                long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE;
                long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE;
@@ -1015,40 +1097,39 @@ public class CalculateAverage_vemana {
                    minDuration = Math.min(minDuration, durationNs);
                    maxDuration = Math.max(maxDuration, durationNs);
-                   minBegin = Math.min(minBegin, timestamps[2 * i]);
-                   maxBegin = Math.max(maxBegin, timestamps[2 * i]);
-                   maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1]);
-                   minCompletion = Math.min(minCompletion, timestamps[2 * i + 1]);
+                   minBegin = Math.min(minBegin, timestamps[2 * i] - startTime);
+                   maxBegin = Math.max(maxBegin, timestamps[2 * i] - startTime);
+                   maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1] - startTime);
+                   minCompletion = Math.min(minCompletion, timestamps[2 * i + 1] - startTime);
                }
-               errPrint(
-                       STR."""
+               return STR."""
                        -------------------------------------------------------------------------------------------
-                       \{header}
+                       \{id} Stats
                        -------------------------------------------------------------------------------------------
                        Max duration = \{maxDuration / 1_000_000} ms
                        Min duration = \{minDuration / 1_000_000} ms
-                       Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms
+                       Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms [\{maxCompletion / 1_000_000} - \{minBegin / 1_000_000} ]
                        Completion Timespan[max(end)-min(end)] = \{(maxCompletion - minCompletion) / 1_000_000} ms
                        Begin Timespan[max(begin)-min(begin)] = \{(maxBegin - minBegin) / 1_000_000} ms
-                       Durations = \{toString(durationsMs)} in ms
-                       Begin Timestamps = \{toString(beginMs)} in ms
-                       Completion Timestamps = \{toString(completionsMs)} in ms
-                       """);
+                       Average Duration = \{Arrays.stream(durationsMs)
+                               .average()
+                               .getAsDouble()} ms
+                       Durations = \{toString(durationsMs)} ms
+                       Begin Timestamps = \{toString(beginMs)} ms
+                       Completion Timestamps = \{toString(completionsMs)} ms
+                       """;
            }

-        private static long[] timingsArray(String id) {
-            return switch (id) {
-                case "cleaner" -> cleanerTimes;
-                case "shard" -> threadTimes;
-                default -> throw new RuntimeException("");
-            };
-        }
+            public void recordEnd(int idx) {
+                timestamps[2 * idx + 1] = System.nanoTime();
+                hasData = true;
+            }

-        private static String toString(long[] array) {
-            return Arrays.stream(array)
-                    .mapToObj(x -> String.format("%6d", x))
-                    .collect(Collectors.joining(", ", "[ ", " ]"));
-        }
+            public void recordStart(int idx) {
+                timestamps[2 * idx] = System.nanoTime();
+                hasData = true;
+            }
+        }
     }
 }
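
Usage note for the new tuning knobs: with the key=value argument parsing in main (shown above), a run can now pin the thread count and the munmap trigger from the command line. A hypothetical direct invocation (the challenge harness actually launches this through its own scripts; --enable-preview is required for the STR template processor used in Tracing):

    java --enable-preview CalculateAverage_vemana \
        inputfile=measurements.txt \
        nThreads=8 \
        munmapFraction=0.03 \
        fakeAdvance=false

Setting fakeAdvance=true makes take() hand out ranges without advancing the thread-local pointer, presumably to isolate the chunk-accounting and cleanup paths from the actual scan while tuning.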