Squashing a bunch of commits together. (#428)

Commit #2: Uplift of 7% using native byte order from ByteBuffer.
Commit #1: Minor changes to formatting.

Co-authored-by: vemana <vemana.github@gmail.com>

@@ -41,55 +41,54 @@ import java.util.stream.Collectors;
* remain readable for a majority of SWEs. At a high level, the approach relies on a few principles
* listed herein.
*
* <p>[Exploit Parallelism] Distribute the work into Shards. Separate threads (one per core) process
* Shards and follow it up by merging the results. parallelStream() is appealing but carries
* potential run-time variance (i.e. std. deviation) penalties based on informal testing. Variance
* is not ideal when trying to minimize the maximum worker latency.
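*
* <p>A minimal sketch of the per-core setup (names illustrative, not the exact code below):
* <pre>{@code
* int n = Runtime.getRuntime().availableProcessors();
* ExecutorService pool = Executors.newFixedThreadPool(n);
* List<Future<AggregateResult>> parts = new ArrayList<>();
* for (int i = 0; i < n; i++) {
*   final int shard = i;
*   parts.add(pool.submit(() -> processShard(shard))); // one worker per core
* }
* // merging the per-shard results happens after all workers finish
* }</pre>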
*
* <p>[Use ByteBuffers over MemorySegment] Each Shard is further divided in Chunks. This would've
* been unnecessary except that Shards are too big to be backed by ByteBuffers. Besides,
* MemorySegment appears slower than ByteBuffers. So, to use ByteBuffers, we have to use smaller
* chunks.
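*
* <p>The constraint, concretely: FileChannel.map() caps each mapping at Integer.MAX_VALUE bytes,
* so a multi-GB Shard cannot be one ByteBuffer. A sketch of per-Chunk mapping (names
* illustrative):
* <pre>{@code
* // chunkLen must stay under 2 GB; one mapping per Chunk
* MappedByteBuffer buf =
*     raf.getChannel().map(MapMode.READ_ONLY, chunkStart, chunkLen);
* }</pre>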
*
* <p>[Straggler freedom] The optimization function here is to minimize the maximal worker thread
* completion time. The law of large numbers means that all threads will end up with similar
* amounts of work and similar completion times; but every so often there is a bad sharding and,
* more importantly, cores are not created equal; some will be throttled more than others. So, we
* have a shared {@code LazyShardQueue} that aims to distribute work to minimize the latest
* completion time.
*
* <p>[Work Assignment with LazyShardQueue] The queue provides each thread with its next big-chunk
* until X% of the work remains. Big-chunks belong to the thread and will not be provided to
* another thread. Then, it switches to providing small-chunk sizes. Small-chunks comprise the
* last X% of work and every thread can participate in completing them. Even though the queue is
* shared across threads, there is no communication across threads during the big-chunk phase. The
* queue is effectively a per-thread queue while processing big-chunks. The small-chunk phase uses
* an AtomicLong to coordinate chunk allocation across threads.
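*
* <p>A condensed sketch of the two-phase take() (field names illustrative; the real logic is in
* SerialLazyShardQueue below):
* <pre>{@code
* ByteRange take(int shardIdx) {
*   long start = nextStarts[shardIdx << 4];        // private cursor, no contention
*   if (start < bigChunkLimit) {
*     nextStarts[shardIdx << 4] = start + bigChunkSize;
*     return rangeAt(start, bigChunkSize);         // big-chunk phase: per-thread
*   }
*   long s = commonPool.getAndAdd(smallChunkSize); // small-chunk phase: shared
*   return s < fileSize ? rangeAt(s, smallChunkSize) : null;
* }
* }</pre>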
*
* <p>[Chunk processing] Chunk processing is typical. Process line by line. Find a hash function
* (polynomial hash fns are slow, but will work fine), hash the city name, resolve conflicts using
* linear probing and then accumulate the temperature into the appropriate hash slot. The key
* element then is how fast can you identify the hash slot, read the temperature and update the new
* temperature in the slot (i.e. min, max, count).
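*
* <p>A sketch of the per-line hot path (names illustrative):
* <pre>{@code
* int slot = hash & slotsMask;             // power-of-two table
* while (!cityMatches(slot, cityBytes)) {  // resolve collisions by
*   slot = (slot + 1) & slotsMask;         // linear probing
* }
* stats[slot].accumulate(temperature);     // update min, max, sum, count
* }</pre>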
*
* <p>[Cache friendliness] 7502P and my machine (7950X) offer 4MB L3 cache/core. This means we can
* hope to fit all our datastructures in L3 cache. Since SMT is turned on, the Runtime's available
* processors will show twice the number of actual cores and so we get 2MB L3 cache/thread. To be
* safe, we try to stay within 1.8 MB/thread and size our hashtable appropriately.
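*
* <p>Illustrative sizing arithmetic (the per-slot footprint here is assumed, not measured from
* this file):
* <pre>{@code
* long budgetBytes = 1_800_000;  // ~1.8 MB per thread
* int bytesPerSlot = 32;         // assumed per-slot footprint
* // largest power of two that fits: 1.8 MB / 32 B ~ 56k -> 2^15 = 32768 slots
* int hashtableSizeBits = 63 - Long.numberOfLeadingZeros(budgetBytes / bytesPerSlot);
* }</pre>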
*
* <p>[Native ByteOrder is MUCH better] There was almost a 10% lift from reading ints from
* bytebuffers using native byte order. It so happens that both the eval machine (7502P) and my
* machine (7950X) natively use LITTLE_ENDIAN order, which apparently is because x86[-64] is
* little-endian. But, by default, ByteBuffers use BIG_ENDIAN order, which appears to be a
* somewhat strange default from Java.
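*
* <p>The fix is one line per mapped buffer (see setByteBufferToRange below; variable names here
* are illustrative):
* <pre>{@code
* MappedByteBuffer buf = channel.map(MapMode.READ_ONLY, start, len);
* buf.order(ByteOrder.nativeOrder()); // default is BIG_ENDIAN, even on x86
* int x = buf.getInt(pos);            // now a plain load, no byte swap
* }</pre>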
*
* <p>[Allocation] Since MemorySegment seemed slower than ByteBuffers, backing Chunks by bytebuffers
* was the logical option. Creating one ByteBuffer per chunk was no bueno because the system doesn't
* like it (JVM runs out of mapped file handle quota). Other than that, allocation in the hot path
* was avoided.
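*
* <p>This shows up in processShard() below: one mutable ByteRange per worker is re-pointed at
* successive chunks instead of allocating per chunk:
* <pre>{@code
* ByteRange range;
* while ((range = shardQueue.take(threadIdx)) != null) {
*   processRange(range); // the range object is reused across iterations
* }
* }</pre>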
*
* <p>[General approach to fast hashing and temperature reading] Here, it helps to understand the
* various bottlenecks in execution. One particular thing that I kept coming back to was
* understanding the relative costs of instructions: see
* https://www.agner.org/optimize/instruction_tables.pdf. It is helpful to think of hardware as a
@@ -102,24 +101,22 @@ import java.util.stream.Collectors;
* endPos" in a tight loop by breaking it into two pieces: one piece where the check will not be
* needed and a tail piece where it will be needed.
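*
* <p>A sketch of that split (helper names illustrative):
* <pre>{@code
* // Main loop: guaranteed room for a full word, no per-byte bounds check
* while (pos + 4 <= safeEnd) { pos = processWord(mmb, pos); }
* // Tail: the last few bytes, checked one at a time
* while (pos < endPos) { pos = processByte(mmb, pos); }
* }</pre>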
*
* <p>[Understand What Cores Like] Cores like to go straight and loop back. Despite good branch
* prediction, performance sucks with mispredicted branches.
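*
* <p>For example, a data-dependent branch can be traded for straight-line arithmetic
* (illustrative, not from this file):
* <pre>{@code
* // Branchy: mispredicts when signs are random
* int v = negative ? -temperature : temperature;
* // Branch-free: mask is 0 or -1; (x ^ mask) - mask negates when mask == -1
* int mask = -negBit;  // negBit is 0 or 1, derived arithmetically upstream
* int v2 = (temperature ^ mask) - mask;
* }</pre>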
*
* <p>
* [JIT] Java performance requires understanding the JIT. It is helpful to understand what the JIT
* likes though it is still somewhat of a mystery to me. In general, it inlines small methods very
* well and after constant folding, it can optimize quite well across a reasonably deep call chain.
* My experience with the JIT was that everything I tried to tune it made it worse except for one
* parameter. I have a new-found respect for JIT - it likes and understands typical Java idioms.
* <p>[JIT] Java performance requires understanding the JIT. It is helpful to understand what the
* JIT likes though it is still somewhat of a mystery to me. In general, it inlines small methods
* very well and after constant folding, it can optimize quite well across a reasonably deep call
* chain. My experience with the JIT was that everything I tried to tune it made it worse except for
* one parameter. I have a new-found respect for JIT - it likes and understands typical Java idioms.
*
* <p>[Tuning] Nothing was more insightful than actually playing with various tuning parameters. I
* can have all the theories but the hardware and JIT are giant black boxes. I used a bunch of
* tools to optimize: (1) Command-line parameters to tune big and small chunk sizes etc. This was
* also very helpful in forming a mental model of the JIT. Sometimes, it would compile some methods
* and sometimes it would just run them interpreted since the compilation threshold wouldn't be
* reached for intermediate methods. (2) AsyncProfiler - this was the first-line tool to understand
* cache misses and CPU time, and to figure out where to aim the next optimization effort. (3)
* JitWatch - invaluable for forming a mental model and attempting to tune the JIT.
*
* <p>[Things that didn't work]. This is a looong list and the hit rate is quite low. In general,
@@ -140,12 +137,6 @@ import java.util.stream.Collectors;
*/
public class CalculateAverage_vemana {
public static void main(String[] args) throws Exception {
// First process in large chunks without coordination among threads
// Use chunkSizeBits for the large-chunk size
@@ -184,18 +175,26 @@ public class CalculateAverage_vemana {
// - hashtableSizeBits = \{hashtableSizeBits}
// """);
System.out.println(
new Runner(
Path.of("measurements.txt"),
chunkSizeBits,
commonChunkFraction,
commonChunkSizeBits,
hashtableSizeBits)
.getSummaryStatistics());
}
public record AggregateResult(Map<String, Stat> tempStats) {
@Override
public String toString() {
return this.tempStats().entrySet().stream()
.sorted(Entry.comparingByKey())
.map(entry -> "%s=%s".formatted(entry.getKey(), entry.getValue()))
.collect(Collectors.joining(", ", "{", "}"));
}
}
// Mutable to avoid allocation
public static class ByteRange {
@@ -267,11 +266,11 @@ public class CalculateAverage_vemana {
@Override
public String toString() {
return STR."""
ByteRange {
startInBuf = \{startInBuf}
endInBuf = \{endInBuf}
}
""";
}
private long nextNewLine(long pos) {
@@ -285,6 +284,7 @@ public class CalculateAverage_vemana {
private void setByteBufferToRange(long start, long end) {
try {
byteBuffer = raf.getChannel().map(MapMode.READ_ONLY, start, end - start);
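// Native byte order: getInt() can read without a byte swap on little-endian x86.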
byteBuffer.order(ByteOrder.nativeOrder());
}
catch (IOException e) {
throw new RuntimeException(e);
@@ -292,18 +292,22 @@ public class CalculateAverage_vemana {
}
}
public static final class Checks {
public static void checkArg(boolean condition) {
if (!condition) {
throw new IllegalArgumentException();
}
}
private Checks() {
}
}
public interface LazyShardQueue {
ByteRange take(int shardIdx);
}
public static class Runner {
@@ -314,7 +318,10 @@ public class CalculateAverage_vemana {
private final int shardSizeBits;
public Runner(
Path inputFile,
int chunkSizeBits,
double commonChunkFraction,
int commonChunkSizeBits,
int hashtableSizeBits) {
this.inputFile = inputFile;
this.shardSizeBits = chunkSizeBits;
@@ -323,16 +330,12 @@ public class CalculateAverage_vemana {
this.hashtableSizeBits = hashtableSizeBits;
}
AggregateResult getSummaryStatistics() throws Exception {
int processors = Runtime.getRuntime().availableProcessors();
LazyShardQueue shardQueue = new SerialLazyShardQueue(
1L << shardSizeBits, inputFile, processors, commonChunkFraction, commonChunkSizeBits);
List<Future<AggregateResult>> results = new ArrayList<>();
ExecutorService executorService = Executors.newFixedThreadPool(
processors,
runnable -> {
@@ -345,8 +348,8 @@ public class CalculateAverage_vemana {
for (int i = 0; i < processors; i++) {
final int I = i;
final Callable<AggregateResult> callable = () -> {
AggregateResult result = new ShardProcessor(shardQueue, hashtableSizeBits, I).processShard();
finishTimes[I] = System.nanoTime();
return result;
};
@@ -356,7 +359,7 @@ public class CalculateAverage_vemana {
return executorService.submit(() -> merge(results)).get();
}
private AggregateResult merge(List<Future<AggregateResult>> results)
throws ExecutionException, InterruptedException {
Map<String, Stat> output = null;
boolean[] isDone = new boolean[results.size()];
@@ -374,20 +377,20 @@ public class CalculateAverage_vemana {
for (Entry<String, Stat> entry : results.get(i).get().tempStats().entrySet()) {
output.compute(
entry.getKey(),
(key, value) -> value == null ? entry.getValue() : Stat.merge(value, entry.getValue()));
}
}
}
}
}
return new AggregateResult(output);
}
private void printFinishTimes(long[] finishTimes) {
Arrays.sort(finishTimes);
int n = finishTimes.length;
System.err.println(
STR."Finish Delta: \{(finishTimes[n - 1] - finishTimes[0]) / 1_000_000}ms");
}
}
@@ -405,23 +408,29 @@ public class CalculateAverage_vemana {
private final long[] nextStarts;
public SerialLazyShardQueue(
long chunkSize,
Path filePath,
int shards,
double commonChunkFraction,
int commonChunkSizeBits)
throws IOException {
Checks.checkArg(commonChunkFraction < 0.9 && commonChunkFraction >= 0);
var raf = new RandomAccessFile(filePath.toFile(), "r");
this.fileSize = raf.length();
// Common pool
long commonPoolStart = Math.min(
roundToNearestHigherMultipleOf(
chunkSize, (long) (fileSize * (1 - commonChunkFraction))),
fileSize);
this.commonPool = new AtomicLong(commonPoolStart);
this.commonChunkSize = 1L << commonChunkSizeBits;
// Distribute chunks to shards
this.nextStarts = new long[shards << 4]; // thread idx -> 16*idx to avoid cache line conflict
for (long i = 0,
currentStart = 0,
remainingChunks = (commonPoolStart + chunkSize - 1) / chunkSize; i < shards; i++) {
long remainingShards = shards - i;
long currentChunks = (remainingChunks + remainingShards - 1) / remainingShards;
// Shard i handles: [currentStart, currentStart + currentChunks * chunkSize)
@@ -479,7 +488,7 @@ public class CalculateAverage_vemana {
this.state = new ShardProcessorState(hashtableSizeBits);
}
public AggregateResult processShard() {
ByteRange range;
while ((range = shardQueue.take(threadIdx)) != null) {
processRange(range);
@@ -497,7 +506,7 @@ public class CalculateAverage_vemana {
}
}
private AggregateResult result() {
return state.result();
}
}
@@ -527,30 +536,30 @@ public class CalculateAverage_vemana {
x = Integer.reverseBytes(x);
}
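// After the normalization above, byte k of the input word sits at bits (8 * k) of x.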
byte a = (byte) (x >>> 0);
if (a == ';') {
nextPos += 1;
break;
}
byte b = (byte) (x >>> 8);
if (b == ';') {
nextPos += 2;
hash = hash * 31 + ((0xFF & x));
break;
}
byte c = (byte) (x >>> 16);
if (c == ';') {
nextPos += 3;
hash = hash * 31 + ((0xFFFF & x));
break;
}
byte d = (byte) (x >>> 24);
if (d == ';') {
nextPos += 4;
hash = hash * 31 + ((0xFFFFFF & x));
break;
}
@@ -582,16 +591,12 @@ public class CalculateAverage_vemana {
}
linearProbe(
cityLen, hash & slotsMask, negative ? -temperature : temperature, mmb, originalPos);
return nextPos;
}
public AggregateResult result() {
int N = stats.length;
TreeMap<String, Stat> map = new TreeMap<>();
for (int i = 0; i < N; i++) {
@ -599,7 +604,7 @@ public class CalculateAverage_vemana {
map.put(new String(cityNames[i]), stats[i]);
}
}
return new AggregateResult(map);
}
private byte[] copyFrom(MappedByteBuffer mmb, int offsetInMmb, int len) {
@@ -642,6 +647,7 @@ public class CalculateAverage_vemana {
}
}
/** Represents aggregate stats. */
public static class Stat {
public static Stat firstReading(int temp) {