Reduce variance by (1) using common chunks at the end and (2) busy looping on automatic closing of ByteBuffers (#486)

Previously, a straggler could hold up closing of the ByteBuffers. Also:
- Improve tracing code
- Parametrize additional options to aid in tuning

Our previous PR was surprising: parallelizing the munmap() calls did not yield anywhere near the performance gain I expected. The local machine showed a 10% gain while the test machine showed only 2%. I am still not clear why; the two best theories I have are (1) variance due to stragglers, which this change addresses, and (2) munmap() being either too fast or too slow relative to the other instructions compared to our local machine, though I don't know which. We'll have to use adaptive tuning, but that belongs in a different change.
parent 144a6af164
commit 6e3893c6a6
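The busy loop in (2), in essence: a single cleanup task polls the shard futures and unmaps a shard's buffers as soon as that shard's future reports done, so one straggler no longer delays reclaiming every other shard's mappings. A minimal self-contained sketch of the pattern, with a hypothetical releaseShard callback standing in for shardQueue.close("Ending Cleaner", i) from the diff below:

    import java.util.List;
    import java.util.concurrent.Future;
    import java.util.function.IntConsumer;

    class BusyLoopCleaner {
        // Spins over the futures and releases each shard's resources the moment
        // that shard finishes. The busy-wait is deliberate: the polling thread has
        // nothing else to do, and polling keeps cleanup latency low.
        static void closeAsShardsFinish(List<? extends Future<?>> shards, IntConsumer releaseShard) {
            int n = shards.size();
            boolean[] done = new boolean[n];
            int remaining = n;
            while (remaining > 0) {
                for (int i = 0; i < n; i++) {
                    if (!done[i] && shards.get(i).isDone()) {
                        done[i] = true;
                        remaining--;
                        releaseShard.accept(i); // e.g. shardQueue.close("Ending Cleaner", i)
                    }
                }
            }
        }
    }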
@@ -171,7 +171,7 @@ public class CalculateAverage_vemana {
         int chunkSizeBits = 20;

         // For the last commonChunkFraction fraction of total work, use smaller chunk sizes
-        double commonChunkFraction = 0;
+        double commonChunkFraction = 0.03;

         // Use commonChunkSizeBits for the small-chunk size
         int commonChunkSizeBits = 18;
@@ -181,11 +181,17 @@ public class CalculateAverage_vemana {

         int minReservedBytesAtFileTail = 9;

+        int nThreads = -1;
+
         String inputFile = "measurements.txt";

+        double munmapFraction = 0.03;
+
+        boolean fakeAdvance = false;
+
         for (String arg : args) {
-            String key = arg.substring(0, arg.indexOf('='));
-            String value = arg.substring(key.length() + 1);
+            String key = arg.substring(0, arg.indexOf('=')).trim();
+            String value = arg.substring(key.length() + 1).trim();
             switch (key) {
                 case "chunkSizeBits":
                     chunkSizeBits = Integer.parseInt(value);
@@ -202,6 +208,15 @@ public class CalculateAverage_vemana {
                 case "inputfile":
                     inputFile = value;
                     break;
+                case "munmapFraction":
+                    munmapFraction = Double.parseDouble(value);
+                    break;
+                case "fakeAdvance":
+                    fakeAdvance = Boolean.parseBoolean(value);
+                    break;
+                case "nThreads":
+                    nThreads = Integer.parseInt(value);
+                    break;
                 default:
                     throw new IllegalArgumentException("Unknown argument: " + arg);
             }
@@ -218,14 +233,17 @@ public class CalculateAverage_vemana {
         System.out.println(
                 new Runner(
                         Path.of(inputFile),
+                        nThreads,
                         chunkSizeBits,
                         commonChunkFraction,
                         commonChunkSizeBits,
                         hashtableSizeBits,
-                        minReservedBytesAtFileTail)
+                        minReservedBytesAtFileTail,
+                        munmapFraction,
+                        fakeAdvance)
                         .getSummaryStatistics());

-        Tracing.recordEvent("After printing result");
+        Tracing.recordEvent("Final result printed");
     }

     public record AggregateResult(Map<String, Stat> tempStats) {
@@ -286,8 +304,8 @@ public class CalculateAverage_vemana {
             bufferEnd = bufferStart = -1;
         }

-        public void close(int shardIdx) {
-            Tracing.recordWorkStart("cleaner", shardIdx);
+        public void close(String closerId, int shardIdx) {
+            Tracing.recordWorkStart(closerId, shardIdx);
             if (byteBuffer != null) {
                 unclosedBuffers.add(byteBuffer);
             }
@@ -297,7 +315,7 @@ public class CalculateAverage_vemana {
             unclosedBuffers.clear();
             bufferEnd = bufferStart = -1;
             byteBuffer = null;
-            Tracing.recordWorkEnd("cleaner", shardIdx);
+            Tracing.recordWorkEnd(closerId, shardIdx);
         }

         public void setRange(long rangeStart, long rangeEnd) {
@@ -383,7 +401,7 @@ public class CalculateAverage_vemana {

     public interface LazyShardQueue {

-        void close(int shardIdx);
+        void close(String closerId, int shardIdx);

         Optional<ByteRange> fileTailEndWork(int idx);

@@ -415,37 +433,48 @@ public class CalculateAverage_vemana {

         private final double commonChunkFraction;
         private final int commonChunkSizeBits;
+        private final boolean fakeAdvance;
         private final int hashtableSizeBits;
         private final Path inputFile;
         private final int minReservedBytesAtFileTail;
+        private final double munmapFraction;
+        private final int nThreads;
         private final int shardSizeBits;

         public Runner(
                 Path inputFile,
+                int nThreads,
                 int chunkSizeBits,
                 double commonChunkFraction,
                 int commonChunkSizeBits,
                 int hashtableSizeBits,
-                int minReservedBytesAtFileTail) {
+                int minReservedBytesAtFileTail,
+                double munmapFraction,
+                boolean fakeAdvance) {
             this.inputFile = inputFile;
+            this.nThreads = nThreads;
             this.shardSizeBits = chunkSizeBits;
             this.commonChunkFraction = commonChunkFraction;
             this.commonChunkSizeBits = commonChunkSizeBits;
             this.hashtableSizeBits = hashtableSizeBits;
             this.minReservedBytesAtFileTail = minReservedBytesAtFileTail;
+            this.munmapFraction = munmapFraction;
+            this.fakeAdvance = fakeAdvance;
         }

         AggregateResult getSummaryStatistics() throws Exception {
-            int nThreads = Runtime.getRuntime().availableProcessors();
+            int nThreads = this.nThreads < 0 ? Runtime.getRuntime().availableProcessors() : this.nThreads;

             LazyShardQueue shardQueue = new SerialLazyShardQueue(
                     1L << shardSizeBits,
                     inputFile,
                     nThreads,
                     commonChunkFraction,
                     commonChunkSizeBits,
-                    minReservedBytesAtFileTail);
+                    minReservedBytesAtFileTail,
+                    munmapFraction,
+                    fakeAdvance);

-            List<Future<AggregateResult>> results = new ArrayList<>();
             ExecutorService executorService = Executors.newFixedThreadPool(
                     nThreads,
                     runnable -> {
@@ -454,42 +483,56 @@ public class CalculateAverage_vemana {
                         return thread;
                     });

+            List<Future<AggregateResult>> results = new ArrayList<>();
             for (int i = 0; i < nThreads; i++) {
                 final int shardIdx = i;
                 final Callable<AggregateResult> callable = () -> {
-                    Tracing.recordWorkStart("shard", shardIdx);
+                    Tracing.recordWorkStart("Shard", shardIdx);
                     AggregateResult result = new ShardProcessor(shardQueue, hashtableSizeBits, shardIdx).processShard();
-                    Tracing.recordWorkEnd("shard", shardIdx);
+                    Tracing.recordWorkEnd("Shard", shardIdx);
                     return result;
                 };
                 results.add(executorService.submit(callable));
             }
             Tracing.recordEvent("Basic push time");

-            AggregateResult result = executorService.submit(() -> merge(results)).get();
+            // This particular sequence of Futures is so that both merge and munmap() can work as shards
+            // finish their computation without blocking on the entire set of shards to complete. In
+            // particular, munmap() doesn't need to wait on merge.
+            // First, submit a task to merge the results and then submit a task to cleanup bytebuffers
+            // from completed shards.
+            Future<AggregateResult> resultFutures = executorService.submit(() -> merge(results));
+            // Note that munmap() is serial and not parallel and hence we use just one thread.
+            executorService.submit(() -> closeByteBuffers(results, shardQueue));
+
+            AggregateResult result = resultFutures.get();
             Tracing.recordEvent("Merge results received");

-            // Note that munmap() is serial and not parallel
-            executorService.submit(
-                    () -> {
-                        for (int i = 0; i < nThreads; i++) {
-                            shardQueue.close(i);
-                        }
-                    });
-
-            Tracing.recordEvent("Waiting for executor shutdown");
+            Tracing.recordEvent("About to shutdown executor and wait");
             executorService.shutdown();
             executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS);

             Tracing.recordEvent("Executor terminated");
-            Tracing.analyzeWorkThreads("cleaner", nThreads);
-            Tracing.recordEvent("After cleaner finish printed");

+            Tracing.analyzeWorkThreads(nThreads);
             return result;
         }

+        private void closeByteBuffers(
+                List<Future<AggregateResult>> results, LazyShardQueue shardQueue) {
+            int n = results.size();
+            boolean[] isDone = new boolean[n];
+            int remaining = results.size();
+            while (remaining > 0) {
+                for (int i = 0; i < n; i++) {
+                    if (!isDone[i] && results.get(i).isDone()) {
+                        remaining--;
+                        isDone[i] = true;
+                        shardQueue.close("Ending Cleaner", i);
+                    }
+                }
+            }
+        }
+
         private AggregateResult merge(List<Future<AggregateResult>> results)
                 throws ExecutionException, InterruptedException {
             Tracing.recordEvent("Merge start time");
@@ -516,7 +559,6 @@ public class CalculateAverage_vemana {
                 }
             }
             Tracing.recordEvent("Merge end time");
-            Tracing.analyzeWorkThreads("shard", results.size());
             return new AggregateResult(output);
         }
     }
@@ -532,6 +574,7 @@ public class CalculateAverage_vemana {
         private final long commonChunkSize;
         private final AtomicLong commonPool;
         private final long effectiveFileSize;
+        private final boolean fakeAdvance;
         private final long fileSize;
         private final long[] perThreadData;
         private final RandomAccessFile raf;
@@ -543,8 +586,11 @@ public class CalculateAverage_vemana {
                 int shards,
                 double commonChunkFraction,
                 int commonChunkSizeBits,
-                int fileTailReservedBytes)
+                int fileTailReservedBytes,
+                double munmapFraction,
+                boolean fakeAdvance)
                 throws IOException {
+            this.fakeAdvance = fakeAdvance;
             Checks.checkArg(commonChunkFraction < 0.9 && commonChunkFraction >= 0);
             Checks.checkArg(fileTailReservedBytes >= 0);
             this.raf = new RandomAccessFile(filePath.toFile(), "r");
@@ -580,8 +626,8 @@ public class CalculateAverage_vemana {
             // its work, where R = relative speed of unmap compared to the computation.
             // For our problem, R ~ 75 because unmap unmaps 30GB/sec (but, it is serial) while
             // cores go through data at the rate of 400MB/sec.
-            perThreadData[pos + 3] = (long) (currentChunks * (0.03 * (shards - i)));
-            perThreadData[pos + 4] = 1;
+            perThreadData[pos + 3] = (long) (currentChunks * (munmapFraction * (shards - i)));
+            perThreadData[pos + 4] = 1; // true iff munmap() hasn't been triggered yet
             currentStart += currentChunks * chunkSize;
             remainingChunks -= currentChunks;
         }
@@ -596,8 +642,8 @@ public class CalculateAverage_vemana {
         }

         @Override
-        public void close(int shardIdx) {
-            byteRanges[shardIdx << 4].close(shardIdx);
+        public void close(String closerId, int shardIdx) {
+            byteRanges[shardIdx << 4].close(closerId, shardIdx);
         }

         @Override
@@ -616,14 +662,18 @@ public class CalculateAverage_vemana {
         public ByteRange take(int shardIdx) {
             // Try for thread local range
             final int pos = shardIdx << 4;
-            long rangeStart = perThreadData[pos];
-            final long chunkEnd = perThreadData[pos + 1];
+            final long rangeStart;
             final long rangeEnd;

-            if (rangeStart < chunkEnd) {
+            if (perThreadData[pos + 2] >= 1) {
+                rangeStart = perThreadData[pos];
                 rangeEnd = rangeStart + chunkSize;
-                perThreadData[pos] = rangeEnd;
+                // Don't do this in the if-check; it causes negative values that trigger intermediate
+                // cleanup
                 perThreadData[pos + 2]--;
+                if (!fakeAdvance) {
+                    perThreadData[pos] = rangeEnd;
+                }
             }
             else {
                 rangeStart = commonPool.getAndAdd(commonChunkSize);
@@ -634,8 +684,8 @@ public class CalculateAverage_vemana {
                 rangeEnd = rangeStart + commonChunkSize;
             }

-            if (perThreadData[pos + 2] <= perThreadData[pos + 3] && perThreadData[pos + 4] > 0) {
-                if (attemptClose(shardIdx)) {
+            if (perThreadData[pos + 2] < perThreadData[pos + 3] && perThreadData[pos + 4] > 0) {
+                if (attemptIntermediateClose(shardIdx)) {
                     perThreadData[pos + 4]--;
                 }
             }
@@ -645,9 +695,9 @@ public class CalculateAverage_vemana {
             return chunk;
         }

-        private boolean attemptClose(int shardIdx) {
+        private boolean attemptIntermediateClose(int shardIdx) {
             if (seqLock.acquire()) {
-                byteRanges[shardIdx << 4].close(shardIdx);
+                close("Intermediate Cleaner", shardIdx);
                 seqLock.release();
                 return true;
             }
@@ -964,12 +1014,22 @@ public class CalculateAverage_vemana {

     static class Tracing {

-        private static final long[] cleanerTimes = new long[1 << 6 << 1];
-        private static final long[] threadTimes = new long[1 << 6 << 1];
+        private static final Map<String, ThreadTimingsArray> knownWorkThreadEvents;
         private static long startTime;

-        static void analyzeWorkThreads(String id, int nThreads) {
-            printTimingsAnalysis(id + " Stats", nThreads, timingsArray(id));
+        static {
+            // Maintain the ordering to be chronological in execution
+            // Map.of(..) screws up ordering
+            knownWorkThreadEvents = new LinkedHashMap<>();
+            for (String id : List.of("Shard", "Intermediate Cleaner", "Ending Cleaner")) {
+                knownWorkThreadEvents.put(id, new ThreadTimingsArray(id, 1 << 6 << 1));
+            }
+        }
+
+        static void analyzeWorkThreads(int nThreads) {
+            for (ThreadTimingsArray array : knownWorkThreadEvents.values()) {
+                errPrint(array.analyze(nThreads));
+            }
         }

         static void recordAppStart() {
@@ -981,11 +1041,11 @@ public class CalculateAverage_vemana {
         }

         static void recordWorkEnd(String id, int threadId) {
-            timingsArray(id)[2 * threadId + 1] = System.nanoTime();
+            knownWorkThreadEvents.get(id).recordEnd(threadId);
         }

         static void recordWorkStart(String id, int threadId) {
-            timingsArray(id)[2 * threadId] = System.nanoTime();
+            knownWorkThreadEvents.get(id).recordStart(threadId);
         }

         /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -998,57 +1058,78 @@ public class CalculateAverage_vemana {
             errPrint(STR."\{message} = \{(nanoTime - startTime) / 1_000_000}ms");
         }

-        private static void printTimingsAnalysis(String header, int nThreads, long[] timestamps) {
-            long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE;
-            long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE;
-            long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE;
+        public static class ThreadTimingsArray {

-            long[] durationsMs = new long[nThreads];
-            long[] completionsMs = new long[nThreads];
-            long[] beginMs = new long[nThreads];
-            for (int i = 0; i < nThreads; i++) {
-                long durationNs = timestamps[2 * i + 1] - timestamps[2 * i];
-                durationsMs[i] = durationNs / 1_000_000;
-                completionsMs[i] = (timestamps[2 * i + 1] - startTime) / 1_000_000;
-                beginMs[i] = (timestamps[2 * i] - startTime) / 1_000_000;
+            private static String toString(long[] array) {
+                return Arrays.stream(array)
+                        .map(x -> x < 0 ? -1 : x)
+                        .mapToObj(x -> String.format("%6d", x))
+                        .collect(Collectors.joining(", ", "[ ", " ]"));
+            }

-                minDuration = Math.min(minDuration, durationNs);
-                maxDuration = Math.max(maxDuration, durationNs);
+            private final String id;
+            private final long[] timestamps;
+            private boolean hasData = false;

-                minBegin = Math.min(minBegin, timestamps[2 * i]);
-                maxBegin = Math.max(maxBegin, timestamps[2 * i]);
+            public ThreadTimingsArray(String id, int maxSize) {
+                this.timestamps = new long[maxSize];
+                this.id = id;
+            }

-                maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1]);
-                minCompletion = Math.min(minCompletion, timestamps[2 * i + 1]);
-            }
-            errPrint(
-                    STR."""
+            public String analyze(int nThreads) {
+                if (!hasData) {
+                    return "%s has no thread timings data".formatted(id);
+                }
+                Checks.checkArg(nThreads <= timestamps.length);
+                long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE;
+                long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE;
+                long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE;
+
+                long[] durationsMs = new long[nThreads];
+                long[] completionsMs = new long[nThreads];
+                long[] beginMs = new long[nThreads];
+                for (int i = 0; i < nThreads; i++) {
+                    long durationNs = timestamps[2 * i + 1] - timestamps[2 * i];
+                    durationsMs[i] = durationNs / 1_000_000;
+                    completionsMs[i] = (timestamps[2 * i + 1] - startTime) / 1_000_000;
+                    beginMs[i] = (timestamps[2 * i] - startTime) / 1_000_000;
+
+                    minDuration = Math.min(minDuration, durationNs);
+                    maxDuration = Math.max(maxDuration, durationNs);
+
+                    minBegin = Math.min(minBegin, timestamps[2 * i] - startTime);
+                    maxBegin = Math.max(maxBegin, timestamps[2 * i] - startTime);
+
+                    maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1] - startTime);
+                    minCompletion = Math.min(minCompletion, timestamps[2 * i + 1] - startTime);
+                }
+                return STR."""
                     -------------------------------------------------------------------------------------------
-                    \{header}
+                    \{id} Stats
                     -------------------------------------------------------------------------------------------
                     Max duration = \{maxDuration / 1_000_000} ms
                     Min duration = \{minDuration / 1_000_000} ms
-                    Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms
+                    Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms [\{maxCompletion / 1_000_000} - \{minBegin / 1_000_000} ]
                     Completion Timespan[max(end)-min(end)] = \{(maxCompletion - minCompletion) / 1_000_000} ms
                     Begin Timespan[max(begin)-min(begin)] = \{(maxBegin - minBegin) / 1_000_000} ms
-                    Durations = \{toString(durationsMs)} in ms
-                    Begin Timestamps = \{toString(beginMs)} in ms
-                    Completion Timestamps = \{toString(completionsMs)} in ms
-                    """);
-        }
+                    Average Duration = \{Arrays.stream(durationsMs)
+                            .average()
+                            .getAsDouble()} ms
+                    Durations = \{toString(durationsMs)} ms
+                    Begin Timestamps = \{toString(beginMs)} ms
+                    Completion Timestamps = \{toString(completionsMs)} ms
+                    """;
+            }

-        private static long[] timingsArray(String id) {
-            return switch (id) {
-                case "cleaner" -> cleanerTimes;
-                case "shard" -> threadTimes;
-                default -> throw new RuntimeException("");
-            };
-        }
+            public void recordEnd(int idx) {
+                timestamps[2 * idx + 1] = System.nanoTime();
+                hasData = true;
+            }

-        private static String toString(long[] array) {
-            return Arrays.stream(array)
-                    .mapToObj(x -> String.format("%6d", x))
-                    .collect(Collectors.joining(", ", "[ ", " ]"));
+            public void recordStart(int idx) {
+                timestamps[2 * idx] = System.nanoTime();
+                hasData = true;
+            }
         }
     }
 }
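The variance reduction in (1), in isolation: each thread first works through a private range in large chunks, and the last commonChunkFraction of the file is instead handed out in small chunks from a shared atomic cursor, so faster threads absorb the tail and no single thread is left holding a large final chunk. A simplified sketch under assumed names (TailSharingScheduler and its explicit fields are illustrative; the actual entry packs this state into perThreadData and aligns chunks to record boundaries, which this sketch ignores):

    import java.util.concurrent.atomic.AtomicLong;

    class TailSharingScheduler {
        private final long chunkSize;        // big private chunk, e.g. 1 << 20
        private final long commonChunkSize;  // small shared chunk, e.g. 1 << 18
        private final long[] nextStart;      // per-thread cursor into its private range
        private final long[] privateEnd;     // per-thread private range end
        private final AtomicLong commonPool; // shared cursor over the file tail
        private final long fileEnd;

        TailSharingScheduler(long fileSize, int threads, double commonFraction,
                             long chunkSize, long commonChunkSize) {
            this.chunkSize = chunkSize;
            this.commonChunkSize = commonChunkSize;
            this.fileEnd = fileSize;
            long commonStart = (long) (fileSize * (1 - commonFraction));
            this.commonPool = new AtomicLong(commonStart);
            this.nextStart = new long[threads];
            this.privateEnd = new long[threads];
            long perThread = commonStart / threads;
            for (int i = 0; i < threads; i++) {
                nextStart[i] = i * perThread;
                privateEnd[i] = (i == threads - 1) ? commonStart : (i + 1) * perThread;
            }
        }

        /** Returns {start, end} of the next chunk for this thread, or null when all work is done. */
        long[] take(int thread) {
            long start = nextStart[thread];
            if (start < privateEnd[thread]) { // private phase: big chunks, no contention
                long end = Math.min(start + chunkSize, privateEnd[thread]);
                nextStart[thread] = end;
                return new long[]{ start, end };
            }
            start = commonPool.getAndAdd(commonChunkSize); // shared tail: small chunks
            if (start >= fileEnd) {
                return null;
            }
            return new long[]{ start, Math.min(start + commonChunkSize, fileEnd) };
        }
    }

With the change applied, these knobs are tunable from the command line as key=value arguments, e.g. chunkSizeBits=20 commonChunkSizeBits=18 munmapFraction=0.05 nThreads=8 fakeAdvance=false.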