Reduce variance by (1) using common chunks at the end and (2) busy-looping
on automatic closing of ByteBuffers (#486)

Previously, a straggler could hold up closing the ByteBuffers.

Also:
- Improve the Tracing code
- Parametrize additional options to aid in tuning

The previous PR was surprising: parallelizing the munmap() call did not
yield anywhere near the performance gain I expected. The local machine
showed a 10% gain while the testing machine showed only 2%. I am still
not sure why; the two best theories I have are
1) Variance due to stragglers (which this change addresses)
2) munmap() is either too fast or too slow relative to the other
   instructions compared to our local machine; I don't know which. We'll
   have to use adaptive tuning, but that's for a different change.
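
The busy loop referred to in the title is the new closeByteBuffers method in the
diff below: it polls each shard's Future and unmaps that shard's ByteBuffers the
moment the shard completes, so a single straggler no longer delays cleanup of
everyone else's buffers. The core of the pattern, condensed from the diff:

    // Spin over the shard futures; close a shard's buffers as soon as it finishes.
    while (remaining > 0) {
        for (int i = 0; i < n; i++) {
            if (!isDone[i] && results.get(i).isDone()) {
                remaining--;
                isDone[i] = true;
                shardQueue.close("Ending Cleaner", i);
            }
        }
    }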
Vemana, 2024-01-20 02:17:55 +05:30 (committed by GitHub)
parent 144a6af164
commit 6e3893c6a6


@@ -171,7 +171,7 @@ public class CalculateAverage_vemana {
int chunkSizeBits = 20;
// For the last commonChunkFraction fraction of total work, use smaller chunk sizes
-double commonChunkFraction = 0;
+double commonChunkFraction = 0.03;
// Use commonChunkSizeBits for the small-chunk size
int commonChunkSizeBits = 18;
@@ -181,11 +181,17 @@ public class CalculateAverage_vemana {
int minReservedBytesAtFileTail = 9;
+int nThreads = -1;
String inputFile = "measurements.txt";
+double munmapFraction = 0.03;
+boolean fakeAdvance = false;
for (String arg : args) {
-String key = arg.substring(0, arg.indexOf('='));
-String value = arg.substring(key.length() + 1);
+String key = arg.substring(0, arg.indexOf('=')).trim();
+String value = arg.substring(key.length() + 1).trim();
switch (key) {
case "chunkSizeBits":
chunkSizeBits = Integer.parseInt(value);
@@ -202,6 +208,15 @@ public class CalculateAverage_vemana {
case "inputfile":
inputFile = value;
break;
case "munmapFraction":
munmapFraction = Double.parseDouble(value);
break;
case "fakeAdvance":
fakeAdvance = Boolean.parseBoolean(value);
break;
case "nThreads":
nThreads = Integer.parseInt(value);
break;
default:
throw new IllegalArgumentException("Unknown argument: " + arg);
}
@@ -218,14 +233,17 @@ public class CalculateAverage_vemana {
System.out.println(
new Runner(
Path.of(inputFile),
+nThreads,
chunkSizeBits,
commonChunkFraction,
commonChunkSizeBits,
hashtableSizeBits,
-minReservedBytesAtFileTail)
+minReservedBytesAtFileTail,
+munmapFraction,
+fakeAdvance)
.getSummaryStatistics());
Tracing.recordEvent("After printing result");
Tracing.recordEvent("Final result printed");
}
public record AggregateResult(Map<String, Stat> tempStats) {
@@ -286,8 +304,8 @@ public class CalculateAverage_vemana {
bufferEnd = bufferStart = -1;
}
-public void close(int shardIdx) {
-Tracing.recordWorkStart("cleaner", shardIdx);
+public void close(String closerId, int shardIdx) {
+Tracing.recordWorkStart(closerId, shardIdx);
if (byteBuffer != null) {
unclosedBuffers.add(byteBuffer);
}
@@ -297,7 +315,7 @@ public class CalculateAverage_vemana {
unclosedBuffers.clear();
bufferEnd = bufferStart = -1;
byteBuffer = null;
Tracing.recordWorkEnd("cleaner", shardIdx);
Tracing.recordWorkEnd(closerId, shardIdx);
}
public void setRange(long rangeStart, long rangeEnd) {
@@ -383,7 +401,7 @@ public class CalculateAverage_vemana {
public interface LazyShardQueue {
-void close(int shardIdx);
+void close(String closerId, int shardIdx);
Optional<ByteRange> fileTailEndWork(int idx);
@@ -415,37 +433,48 @@ public class CalculateAverage_vemana {
private final double commonChunkFraction;
private final int commonChunkSizeBits;
+private final boolean fakeAdvance;
private final int hashtableSizeBits;
private final Path inputFile;
private final int minReservedBytesAtFileTail;
+private final double munmapFraction;
+private final int nThreads;
private final int shardSizeBits;
public Runner(
Path inputFile,
+int nThreads,
int chunkSizeBits,
double commonChunkFraction,
int commonChunkSizeBits,
int hashtableSizeBits,
-int minReservedBytesAtFileTail) {
+int minReservedBytesAtFileTail,
+double munmapFraction,
+boolean fakeAdvance) {
this.inputFile = inputFile;
+this.nThreads = nThreads;
this.shardSizeBits = chunkSizeBits;
this.commonChunkFraction = commonChunkFraction;
this.commonChunkSizeBits = commonChunkSizeBits;
this.hashtableSizeBits = hashtableSizeBits;
this.minReservedBytesAtFileTail = minReservedBytesAtFileTail;
+this.munmapFraction = munmapFraction;
+this.fakeAdvance = fakeAdvance;
}
AggregateResult getSummaryStatistics() throws Exception {
-int nThreads = Runtime.getRuntime().availableProcessors();
+int nThreads = this.nThreads < 0 ? Runtime.getRuntime().availableProcessors() : this.nThreads;
LazyShardQueue shardQueue = new SerialLazyShardQueue(
1L << shardSizeBits,
inputFile,
nThreads,
commonChunkFraction,
commonChunkSizeBits,
-minReservedBytesAtFileTail);
+minReservedBytesAtFileTail,
+munmapFraction,
+fakeAdvance);
+List<Future<AggregateResult>> results = new ArrayList<>();
ExecutorService executorService = Executors.newFixedThreadPool(
nThreads,
runnable -> {
@@ -454,42 +483,56 @@ public class CalculateAverage_vemana {
return thread;
});
-List<Future<AggregateResult>> results = new ArrayList<>();
for (int i = 0; i < nThreads; i++) {
final int shardIdx = i;
final Callable<AggregateResult> callable = () -> {
Tracing.recordWorkStart("shard", shardIdx);
Tracing.recordWorkStart("Shard", shardIdx);
AggregateResult result = new ShardProcessor(shardQueue, hashtableSizeBits, shardIdx).processShard();
Tracing.recordWorkEnd("shard", shardIdx);
Tracing.recordWorkEnd("Shard", shardIdx);
return result;
};
results.add(executorService.submit(callable));
}
Tracing.recordEvent("Basic push time");
-AggregateResult result = executorService.submit(() -> merge(results)).get();
+// This particular sequence of Futures is so that both merge and munmap() can work as shards
+// finish their computation without blocking on the entire set of shards to complete. In
+// particular, munmap() doesn't need to wait on merge.
+// First, submit a task to merge the results and then submit a task to cleanup bytebuffers
+// from completed shards.
+Future<AggregateResult> resultFutures = executorService.submit(() -> merge(results));
+// Note that munmap() is serial and not parallel and hence we use just one thread.
+executorService.submit(() -> closeByteBuffers(results, shardQueue));
+AggregateResult result = resultFutures.get();
Tracing.recordEvent("Merge results received");
-// Note that munmap() is serial and not parallel
-executorService.submit(
-() -> {
-for (int i = 0; i < nThreads; i++) {
-shardQueue.close(i);
-}
-});
-Tracing.recordEvent("Waiting for executor shutdown");
+Tracing.recordEvent("About to shutdown executor and wait");
executorService.shutdown();
executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS);
Tracing.recordEvent("Executor terminated");
Tracing.analyzeWorkThreads("cleaner", nThreads);
Tracing.recordEvent("After cleaner finish printed");
Tracing.analyzeWorkThreads(nThreads);
return result;
}
+private void closeByteBuffers(
+List<Future<AggregateResult>> results, LazyShardQueue shardQueue) {
+int n = results.size();
+boolean[] isDone = new boolean[n];
+int remaining = results.size();
+while (remaining > 0) {
+for (int i = 0; i < n; i++) {
+if (!isDone[i] && results.get(i).isDone()) {
+remaining--;
+isDone[i] = true;
+shardQueue.close("Ending Cleaner", i);
+}
+}
+}
+}
private AggregateResult merge(List<Future<AggregateResult>> results)
throws ExecutionException, InterruptedException {
Tracing.recordEvent("Merge start time");
@@ -516,7 +559,6 @@ public class CalculateAverage_vemana {
}
}
Tracing.recordEvent("Merge end time");
Tracing.analyzeWorkThreads("shard", results.size());
return new AggregateResult(output);
}
}
@@ -532,6 +574,7 @@ public class CalculateAverage_vemana {
private final long commonChunkSize;
private final AtomicLong commonPool;
private final long effectiveFileSize;
+private final boolean fakeAdvance;
private final long fileSize;
private final long[] perThreadData;
private final RandomAccessFile raf;
@@ -543,8 +586,11 @@ public class CalculateAverage_vemana {
int shards,
double commonChunkFraction,
int commonChunkSizeBits,
-int fileTailReservedBytes)
+int fileTailReservedBytes,
+double munmapFraction,
+boolean fakeAdvance)
throws IOException {
+this.fakeAdvance = fakeAdvance;
Checks.checkArg(commonChunkFraction < 0.9 && commonChunkFraction >= 0);
Checks.checkArg(fileTailReservedBytes >= 0);
this.raf = new RandomAccessFile(filePath.toFile(), "r");
@@ -580,8 +626,8 @@ public class CalculateAverage_vemana {
// its work, where R = relative speed of unmap compared to the computation.
// For our problem, R ~ 75 because unmap unmaps 30GB/sec (but, it is serial) while
// cores go through data at the rate of 400MB/sec.
-perThreadData[pos + 3] = (long) (currentChunks * (0.03 * (shards - i)));
-perThreadData[pos + 4] = 1;
+perThreadData[pos + 3] = (long) (currentChunks * (munmapFraction * (shards - i)));
+perThreadData[pos + 4] = 1; // true iff munmap() hasn't been triggered yet
currentStart += currentChunks * chunkSize;
remainingChunks -= currentChunks;
}
@@ -596,8 +642,8 @@ public class CalculateAverage_vemana {
}
@Override
-public void close(int shardIdx) {
-byteRanges[shardIdx << 4].close(shardIdx);
+public void close(String closerId, int shardIdx) {
+byteRanges[shardIdx << 4].close(closerId, shardIdx);
}
@Override
@@ -616,14 +662,18 @@ public class CalculateAverage_vemana {
public ByteRange take(int shardIdx) {
// Try for thread local range
final int pos = shardIdx << 4;
-long rangeStart = perThreadData[pos];
-final long chunkEnd = perThreadData[pos + 1];
+final long rangeStart;
final long rangeEnd;
-if (rangeStart < chunkEnd) {
+if (perThreadData[pos + 2] >= 1) {
+rangeStart = perThreadData[pos];
rangeEnd = rangeStart + chunkSize;
-perThreadData[pos] = rangeEnd;
+// Don't do this in the if-check; it causes negative values that trigger intermediate
+// cleanup
perThreadData[pos + 2]--;
+if (!fakeAdvance) {
+perThreadData[pos] = rangeEnd;
+}
}
else {
rangeStart = commonPool.getAndAdd(commonChunkSize);
@@ -634,8 +684,8 @@ public class CalculateAverage_vemana {
rangeEnd = rangeStart + commonChunkSize;
}
-if (perThreadData[pos + 2] <= perThreadData[pos + 3] && perThreadData[pos + 4] > 0) {
-if (attemptClose(shardIdx)) {
+if (perThreadData[pos + 2] < perThreadData[pos + 3] && perThreadData[pos + 4] > 0) {
+if (attemptIntermediateClose(shardIdx)) {
perThreadData[pos + 4]--;
}
}
@@ -645,9 +695,9 @@ public class CalculateAverage_vemana {
return chunk;
}
-private boolean attemptClose(int shardIdx) {
+private boolean attemptIntermediateClose(int shardIdx) {
if (seqLock.acquire()) {
-byteRanges[shardIdx << 4].close(shardIdx);
+close("Intermediate Cleaner", shardIdx);
seqLock.release();
return true;
}
@@ -964,12 +1014,22 @@ public class CalculateAverage_vemana {
static class Tracing {
-private static final long[] cleanerTimes = new long[1 << 6 << 1];
-private static final long[] threadTimes = new long[1 << 6 << 1];
+private static final Map<String, ThreadTimingsArray> knownWorkThreadEvents;
private static long startTime;
-static void analyzeWorkThreads(String id, int nThreads) {
-printTimingsAnalysis(id + " Stats", nThreads, timingsArray(id));
+static {
+// Maintain the ordering to be chronological in execution
+// Map.of(..) screws up ordering
+knownWorkThreadEvents = new LinkedHashMap<>();
+for (String id : List.of("Shard", "Intermediate Cleaner", "Ending Cleaner")) {
+knownWorkThreadEvents.put(id, new ThreadTimingsArray(id, 1 << 6 << 1));
+}
+}
+static void analyzeWorkThreads(int nThreads) {
+for (ThreadTimingsArray array : knownWorkThreadEvents.values()) {
+errPrint(array.analyze(nThreads));
+}
}
static void recordAppStart() {
@@ -981,11 +1041,11 @@ public class CalculateAverage_vemana {
}
static void recordWorkEnd(String id, int threadId) {
-timingsArray(id)[2 * threadId + 1] = System.nanoTime();
+knownWorkThreadEvents.get(id).recordEnd(threadId);
}
static void recordWorkStart(String id, int threadId) {
-timingsArray(id)[2 * threadId] = System.nanoTime();
+knownWorkThreadEvents.get(id).recordStart(threadId);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -998,57 +1058,78 @@ public class CalculateAverage_vemana {
errPrint(STR."\{message} = \{(nanoTime - startTime) / 1_000_000}ms");
}
-private static void printTimingsAnalysis(String header, int nThreads, long[] timestamps) {
-long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE;
-long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE;
-long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE;
+public static class ThreadTimingsArray {
-long[] durationsMs = new long[nThreads];
-long[] completionsMs = new long[nThreads];
-long[] beginMs = new long[nThreads];
-for (int i = 0; i < nThreads; i++) {
-long durationNs = timestamps[2 * i + 1] - timestamps[2 * i];
-durationsMs[i] = durationNs / 1_000_000;
-completionsMs[i] = (timestamps[2 * i + 1] - startTime) / 1_000_000;
-beginMs[i] = (timestamps[2 * i] - startTime) / 1_000_000;
+private static String toString(long[] array) {
+return Arrays.stream(array)
+.map(x -> x < 0 ? -1 : x)
+.mapToObj(x -> String.format("%6d", x))
+.collect(Collectors.joining(", ", "[ ", " ]"));
+}
-minDuration = Math.min(minDuration, durationNs);
-maxDuration = Math.max(maxDuration, durationNs);
+private final String id;
+private final long[] timestamps;
+private boolean hasData = false;
-minBegin = Math.min(minBegin, timestamps[2 * i]);
-maxBegin = Math.max(maxBegin, timestamps[2 * i]);
+public ThreadTimingsArray(String id, int maxSize) {
+this.timestamps = new long[maxSize];
+this.id = id;
+}
-maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1]);
-minCompletion = Math.min(minCompletion, timestamps[2 * i + 1]);
-}
-errPrint(
-STR."""
+public String analyze(int nThreads) {
+if (!hasData) {
+return "%s has no thread timings data".formatted(id);
+}
+Checks.checkArg(nThreads <= timestamps.length);
+long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE;
+long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE;
+long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE;
+long[] durationsMs = new long[nThreads];
+long[] completionsMs = new long[nThreads];
+long[] beginMs = new long[nThreads];
+for (int i = 0; i < nThreads; i++) {
+long durationNs = timestamps[2 * i + 1] - timestamps[2 * i];
+durationsMs[i] = durationNs / 1_000_000;
+completionsMs[i] = (timestamps[2 * i + 1] - startTime) / 1_000_000;
+beginMs[i] = (timestamps[2 * i] - startTime) / 1_000_000;
+minDuration = Math.min(minDuration, durationNs);
+maxDuration = Math.max(maxDuration, durationNs);
+minBegin = Math.min(minBegin, timestamps[2 * i] - startTime);
+maxBegin = Math.max(maxBegin, timestamps[2 * i] - startTime);
+maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1] - startTime);
+minCompletion = Math.min(minCompletion, timestamps[2 * i + 1] - startTime);
+}
+return STR."""
-------------------------------------------------------------------------------------------
-\{header}
+\{id} Stats
-------------------------------------------------------------------------------------------
Max duration = \{maxDuration / 1_000_000} ms
Min duration = \{minDuration / 1_000_000} ms
-Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms
+Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms [\{maxCompletion / 1_000_000} - \{minBegin / 1_000_000} ]
Completion Timespan[max(end)-min(end)] = \{(maxCompletion - minCompletion) / 1_000_000} ms
Begin Timespan[max(begin)-min(begin)] = \{(maxBegin - minBegin) / 1_000_000} ms
-Durations = \{toString(durationsMs)} in ms
-Begin Timestamps = \{toString(beginMs)} in ms
-Completion Timestamps = \{toString(completionsMs)} in ms
-""");
-}
+Average Duration = \{Arrays.stream(durationsMs)
+.average()
+.getAsDouble()} ms
+Durations = \{toString(durationsMs)} ms
+Begin Timestamps = \{toString(beginMs)} ms
+Completion Timestamps = \{toString(completionsMs)} ms
+""";
+}
-private static long[] timingsArray(String id) {
-return switch (id) {
-case "cleaner" -> cleanerTimes;
-case "shard" -> threadTimes;
-default -> throw new RuntimeException("");
-};
-}
+public void recordEnd(int idx) {
+timestamps[2 * idx + 1] = System.nanoTime();
+hasData = true;
+}
-private static String toString(long[] array) {
-return Arrays.stream(array)
-.mapToObj(x -> String.format("%6d", x))
-.collect(Collectors.joining(", ", "[ ", " ]"));
+public void recordStart(int idx) {
+timestamps[2 * idx] = System.nanoTime();
+hasData = true;
+}
}
}
}
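
For reference, the newly parametrized options are ordinary key=value program
arguments parsed in main(). A hypothetical invocation exercising the new knobs
(launcher details assumed; --enable-preview is needed because the class uses the
STR string-template preview feature):

    java --enable-preview CalculateAverage_vemana \
        inputfile=measurements.txt \
        nThreads=8 \
        munmapFraction=0.05 \
        fakeAdvance=false

With the default nThreads=-1, the runner falls back to
Runtime.getRuntime().availableProcessors(); munmapFraction controls how early each
shard triggers its one intermediate munmap(), i.e. when its remaining chunk count
falls below the munmapFraction-derived threshold in perThreadData[pos + 3].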