From 80cd738b4b9c1ad2d3fa5b4ae3d0f6e625e946e4 Mon Sep 17 00:00:00 2001 From: Vemana Date: Fri, 26 Jan 2024 03:10:14 +0530 Subject: [PATCH] C style code. Should be ~4secs or lower based on local testing. (#559) 1. Use Unsafe 2. Fit hashtable in L2 cache. 3. If we can find a good hash function, it can fit in L1 cache even. 4. Improve temperature parsing by using a lookup table --- calculate_average_vemanaNonIdiomatic.sh | 31 + prepare_vemanaNonIdiomatic.sh | 20 + .../CalculateAverage_vemanaNonIdiomatic.java | 1654 +++++++++++++++++ 3 files changed, 1705 insertions(+) create mode 100755 calculate_average_vemanaNonIdiomatic.sh create mode 100755 prepare_vemanaNonIdiomatic.sh create mode 100644 src/main/java/dev/morling/onebrc/CalculateAverage_vemanaNonIdiomatic.java diff --git a/calculate_average_vemanaNonIdiomatic.sh b/calculate_average_vemanaNonIdiomatic.sh new file mode 100755 index 0000000..99974ee --- /dev/null +++ b/calculate_average_vemanaNonIdiomatic.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Basics +JAVA_OPTS="" +JAVA_OPTS="$JAVA_OPTS --enable-preview" +JAVA_OPTS="$JAVA_OPTS --add-exports java.base/jdk.internal.ref=ALL-UNNAMED" +JAVA_OPTS="$JAVA_OPTS --add-opens java.base/java.nio=ALL-UNNAMED" + +# JIT parameters +JAVA_OPTS="$JAVA_OPTS -XX:+AlwaysCompileLoopMethods" + +# GC parameters +JAVA_OPTS="$JAVA_OPTS -XX:+UseParallelGC" + + +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_vemanaNonIdiomatic "$@" diff --git a/prepare_vemanaNonIdiomatic.sh b/prepare_vemanaNonIdiomatic.sh new file mode 100755 index 0000000..58dbc24 --- /dev/null +++ b/prepare_vemanaNonIdiomatic.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.1-graal 1>&2 + diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_vemanaNonIdiomatic.java b/src/main/java/dev/morling/onebrc/CalculateAverage_vemanaNonIdiomatic.java new file mode 100644 index 0000000..10b9c1e --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_vemanaNonIdiomatic.java @@ -0,0 +1,1654 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel.MapMode; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.TreeMap; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; +import sun.misc.Unsafe; + +/** + * Unlike its sister submission {@code CalculateAverage_vemana}, this submission employs non + * idiomatic methods such as SWAR and Unsafe. + * + *
<p>
For details on how this solution works, check the documentation on the sister submission. + */ +public class CalculateAverage_vemanaNonIdiomatic { + + public static void main(String[] args) throws Exception { + String className = MethodHandles.lookup().lookupClass().getSimpleName(); + System.err.println( + STR.""" + ------------------------------------------------ + Running \{className} + ------------------------------------------------- + """); + Tracing.recordAppStart(); + Runtime.getRuntime() + .addShutdownHook( + new Thread( + () -> { + Tracing.recordEvent("In Shutdown hook"); + })); + + // First process in large chunks without coordination among threads + // Use chunkSizeBits for the large-chunk size + int chunkSizeBits = 20; + + // For the last commonChunkFraction fraction of total work, use smaller chunk sizes + double commonChunkFraction = 0.03; + + // Use commonChunkSizeBits for the small-chunk size + int commonChunkSizeBits = 18; + + // Size of the hashtable (attempt to fit in L2 of 512KB of eval machine) + int hashtableSizeBits = className.toLowerCase().contains("nonidiomatic") ? 13 : 16; + + // Reserve some number of lines at the end to give us freedom in reading LONGs past ranges + int minReservedBytesAtFileTail = 9; + + // Number of threads + int nThreads = -1; + + String inputFile = "measurements.txt"; + + // Parallelize unmap. Thread #n (n=1,2,..N) unmaps its bytebuffer when + // munmapFraction * n work remains. + double munmapFraction = 0.03; + + boolean fakeAdvance = false; + + for (String arg : args) { + String key = arg.substring(0, arg.indexOf('=')).trim(); + String value = arg.substring(key.length() + 1).trim(); + switch (key) { + case "chunkSizeBits": + chunkSizeBits = Integer.parseInt(value); + break; + case "commonChunkFraction": + commonChunkFraction = Double.parseDouble(value); + break; + case "commonChunkSizeBits": + commonChunkSizeBits = Integer.parseInt(value); + break; + case "hashtableSizeBits": + hashtableSizeBits = Integer.parseInt(value); + break; + case "inputFile": + inputFile = value; + break; + case "munmapFraction": + munmapFraction = Double.parseDouble(value); + break; + case "fakeAdvance": + fakeAdvance = Boolean.parseBoolean(value); + break; + case "nThreads": + nThreads = Integer.parseInt(value); + break; + default: + throw new IllegalArgumentException("Unknown argument: " + arg); + } + } + + System.out.println( + new Runner( + Path.of(inputFile), + nThreads, + chunkSizeBits, + commonChunkFraction, + commonChunkSizeBits, + hashtableSizeBits, + minReservedBytesAtFileTail, + munmapFraction, + fakeAdvance) + .getSummaryStatistics()); + + Tracing.recordEvent("Final result printed"); + } + + public record AggregateResult(Map tempStats) { + + @Override + public String toString() { + return this.tempStats().entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .map(entry -> "%s=%s".formatted(entry.getKey(), entry.getValue())) + .collect(Collectors.joining(", ", "{", "}")); + } + } + + // Mutable to avoid allocation + public static class ByteRange { + + private static final int BUF_SIZE = 1 << 28; + + private final long fileSize; + private final long maxEndPos; // Treat as if the file ends here + private final RandomAccessFile raf; + private final int shardIdx; + private final List unclosedBuffers = new ArrayList<>(); + // ***************** What this is doing and why ***************** + // Reading from ByteBuffer appears faster from MemorySegment, but ByteBuffer can only be + // Integer.MAX_VALUE long; Creating one byteBuffer per chunk kills native 
memory quota + // and JVM crashes without futher parameters. + // + // So, in this solution, create a sliding window of bytebuffers: + // - Create a large bytebuffer that spans the chunk + // - If the next chunk falls outside the byteBuffer, create another byteBuffer that spans the + // chunk. Because chunks are allocated serially, a single large (1<<30) byteBuffer spans + // many successive chunks. + // - In fact, for serial chunk allocation (which is friendly to page faulting anyway), + // the number of created ByteBuffers doesn't exceed [size of shard/(1<<30)] which is less than + // 100/thread and is comfortably below what the JVM can handle (65K) without further param + // tuning + // - This enables (relatively) allocation free chunking implementation. Our chunking impl uses + // fine grained chunking for the last say X% of work to avoid being hostage to stragglers + + ///////////// The PUBLIC API + + public MappedByteBuffer byteBuffer; + public long endAddress; // the virtual memory address corresponding to 'endInBuf' + public int endInBuf; // where the chunk ends inside the buffer + public long startAddress; // the virtual memory address corresponding to 'startInBuf' + public int startInBuf; // where the chunk starts inside the buffer + + ///////////// Private State + + long bufferBaseAddr; // buffer's base virtual memory address + long extentEnd; // byteBuffer's ending coordinate + long extentStart; // byteBuffer's begin coordinate + + // Uninitialized; for mutability + public ByteRange(RandomAccessFile raf, long maxEndPos, int shardIdx) { + this.raf = raf; + this.maxEndPos = maxEndPos; + this.shardIdx = shardIdx; + try { + this.fileSize = raf.length(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + bufferCleanSlate(); + } + + public void close(String closerId) { + Tracing.recordWorkStart(closerId, shardIdx); + bufferCleanSlate(); + for (MappedByteBuffer buf : unclosedBuffers) { + close(buf); + } + unclosedBuffers.clear(); + Tracing.recordWorkEnd(closerId, shardIdx); + } + + public void setRange(long rangeStart, long rangeEnd) { + if (rangeEnd + 1024 > extentEnd || rangeStart < extentStart) { + setByteBufferExtent(rangeStart, Math.min(rangeStart + BUF_SIZE, fileSize)); + } + + if (rangeStart > 0) { + rangeStart = 1 + nextNewLine(rangeStart); + } + else { + rangeStart = 0; + } + + if (rangeEnd < maxEndPos) { + // rangeEnd = 1 + nextNewLine(rangeEnd); // not needed + rangeEnd = 1 + rangeEnd; + } + else { + rangeEnd = maxEndPos; + } + + startInBuf = (int) (rangeStart - extentStart); + endInBuf = (int) (rangeEnd - extentStart); + startAddress = bufferBaseAddr + startInBuf; + endAddress = bufferBaseAddr + endInBuf; + } + + @Override + public String toString() { + return STR.""" + ByteRange { + shard = \{shardIdx} + extentStart = \{extentStart} + extentEnd = \{extentEnd} + startInBuf = \{startInBuf} + endInBuf = \{endInBuf} + startAddress = \{startAddress} + endAddress = \{endAddress} + } + """; + } + + private void bufferCleanSlate() { + if (byteBuffer != null) { + unclosedBuffers.add(byteBuffer); + byteBuffer = null; + } + extentEnd = extentStart = bufferBaseAddr = startAddress = endAddress = -1; + } + + private void close(MappedByteBuffer buffer) { + Method cleanerMethod = Reflection.findMethodNamed(buffer, "cleaner"); + cleanerMethod.setAccessible(true); + Object cleaner = Reflection.invoke(buffer, cleanerMethod); + + Method cleanMethod = Reflection.findMethodNamed(cleaner, "clean"); + cleanMethod.setAccessible(true); + Reflection.invoke(cleaner, cleanMethod); + 
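// The two reflective calls above amount to ((DirectBuffer) buffer).cleaner().clean(),
+            // i.e. the mapping is released eagerly instead of waiting for GC to collect the buffer.
+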
} + + private long getBaseAddr(MappedByteBuffer buffer) { + Method addressMethod = Reflection.findMethodNamed(buffer, "address"); + addressMethod.setAccessible(true); + return (long) Reflection.invoke(buffer, addressMethod); + } + + private long nextNewLine(long pos) { + int nextPos = (int) (pos - extentStart); + while (byteBuffer.get(nextPos) != '\n') { + nextPos++; + } + return nextPos + extentStart; + } + + /** + * Extent different from Range. Range is what needs to be processed. Extent is what the byte + * buffer can read without failing. + */ + private void setByteBufferExtent(long start, long end) { + bufferCleanSlate(); + try { + byteBuffer = raf.getChannel().map(MapMode.READ_ONLY, start, end - start); + byteBuffer.order(ByteOrder.nativeOrder()); + } + catch (IOException e) { + throw new RuntimeException(e); + } + extentStart = start; + extentEnd = end; + bufferBaseAddr = getBaseAddr(byteBuffer); + } + } + + public static final class Checks { + + public static void checkArg(boolean condition) { + if (!condition) { + throw new IllegalArgumentException(); + } + } + + private Checks() { + } + } + + /* + * ENTRY SHAPE + * Ensure alignment boundaries. 4 bytes on 4 byte, 2 bytes on 2 byte etc. + * 32 bytes per entry. + * 96 KB L1 cache. 2048 entries should fully fit + * ------------------- + * str: 14 bytes [Defined by constant STR_FIELD_LEN] + * hash: 2 bytes + * cityNameOffset: 3 bytes // Index in city names array if len > STR_FIELD_LEN bytes + * len: 1 byte // Length of string, in bytes + * sum: 4 bytes + * count: 4 bytes + * max: 2 bytes + * min: 2 bytes + */ + static class EntryData { + + public static final int ENTRY_SIZE_BITS = 5; + + /////////// OFFSETS /////////////// + private static final int OFFSET_STR = 0; + private static final int STR_FIELD_LEN = 14; + private static final int OFFSET_HASH = OFFSET_STR + STR_FIELD_LEN; + private static final int OFFSET_CITY_NAME_EXTRA = OFFSET_HASH + 2; + private static final int OFFSET_LEN = OFFSET_CITY_NAME_EXTRA + 3; + private static final int OFFSET_SUM = OFFSET_LEN + 1; + private static final int OFFSET_COUNT = OFFSET_SUM + 4; + private static final int OFFSET_MAX = OFFSET_COUNT + 4; + private static final int OFFSET_MIN = OFFSET_MAX + 2; + + public static int strFieldLen() { + return STR_FIELD_LEN; + } + + private final EntryMeta entryMeta; + + private long baseAddress; + + public EntryData(EntryMeta entryMeta) { + this.entryMeta = entryMeta; + } + + public long baseAddress() { + return baseAddress; + } + + public String cityNameString() { + int len = len(); + byte[] zeBytes = new byte[len]; + + for (int i = 0; i < Math.min(len, strFieldLen()); i++) { + zeBytes[i] = Unsafely.readByte(baseAddress + i); + } + + if (len > strFieldLen()) { + int rem = len - strFieldLen(); + long ptr = entryMeta.cityNamesAddress(cityNamesOffset()); + for (int i = 0; i < rem; i++) { + zeBytes[strFieldLen() + i] = Unsafely.readByte(ptr + i); + } + } + + return new String(zeBytes); + } + + public int cityNamesOffset() { + return Unsafely.readInt(baseAddress + OFFSET_CITY_NAME_EXTRA) & 0xFFFFFF; + } + + public int count() { + return Unsafely.readInt(baseAddress + OFFSET_COUNT); + } + + public short hash16() { + return Unsafely.readShort(baseAddress + OFFSET_HASH); + } + + public int index() { + return (int) ((baseAddress() - entryMeta.baseAddress(0)) >> ENTRY_SIZE_BITS); + } + + public void init(long srcAddr, int len, short hash16, short temperature) { + // Copy the string + Unsafely.copyMemory(srcAddr, strAddress(), Math.min(len, EntryData.strFieldLen())); + 
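// Names longer than the 14-byte inline str field spill their tail into the shared
+            // city-names arena; the entry keeps only a 3-byte offset plus the full length.
+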
if (len > EntryData.strFieldLen()) { + int remaining = len - EntryData.strFieldLen(); + int cityNamesOffset = entryMeta.getAndIncrementCityNames(remaining); + Unsafely.copyMemory( + srcAddr + EntryData.strFieldLen(), + entryMeta.cityNamesAddress(cityNamesOffset), + remaining); + setCityNameOffset(cityNamesOffset, len); + } + else { + setLen((byte) len); + } + + // and then update the others + setHash16(hash16); + setSum(temperature); + setCount(1); + setMax(temperature); + setMin(temperature); + } + + public boolean isPresent() { + return len() > 0; + } + + public int len() { + return Unsafely.readByte(baseAddress + OFFSET_LEN); + } + + public short max() { + return Unsafely.readShort(baseAddress + OFFSET_MAX); + } + + public short min() { + return Unsafely.readShort(baseAddress + OFFSET_MIN); + } + + public void setBaseAddress(long baseAddress) { + this.baseAddress = baseAddress; + } + + public void setCityNameOffset(int cityNamesOffset, int len) { + // The 24 here is 3 bytes for Cityname extra index + 1 byte for actual len + // that writes 4 bytes in one shot. It is not an offset. + Unsafely.setInt(baseAddress + OFFSET_CITY_NAME_EXTRA, cityNamesOffset | (len << 24)); + } + + public void setCount(int value) { + Unsafely.setInt(baseAddress + OFFSET_COUNT, value); + } + + public void setHash16(short value) { + Unsafely.setShort(baseAddress + OFFSET_HASH, value); + } + + public void setIndex(int index) { + setBaseAddress(entryMeta.baseAddress(index)); + } + + public void setLen(byte value) { + Unsafely.setByte(baseAddress + OFFSET_LEN, value); + } + + public void setMax(short value) { + Unsafely.setShort(baseAddress + OFFSET_MAX, value); + } + + public void setMin(short value) { + Unsafely.setShort(baseAddress + OFFSET_MIN, value); + } + + public void setSum(int value) { + Unsafely.setInt(baseAddress + OFFSET_SUM, value); + } + + public Stat stat() { + return new Stat(min(), max(), sum(), count()); + } + + public long strAddress() { + return baseAddress + OFFSET_STR; + } + + public int sum() { + return Unsafely.readInt(baseAddress + OFFSET_SUM); + } + + public String toString() { + return STR.""" + min = \{min()} + max = \{max()} + count = \{count()} + sum = \{sum()} + """; + } + + public void update(short temperature) { + setMin((short) Math.min(min(), temperature)); + setMax((short) Math.max(max(), temperature)); + setCount(count() + 1); + setSum(sum() + temperature); + } + + public boolean updateOnMatch( + EntryMeta entryMeta, long srcAddr, int len, short hash16, short temperature) { + + // Quick paths + if (len() != len) { + return false; + } + if (hash16() != hash16) { + return false; + } + + // Actual string comparison + if (len <= STR_FIELD_LEN) { + if (!Unsafely.matches(srcAddr, strAddress(), len)) { + return false; + } + } + else { + if (!Unsafely.matches(srcAddr, strAddress(), STR_FIELD_LEN)) { + return false; + } + if (!Unsafely.matches( + srcAddr + STR_FIELD_LEN, + entryMeta.cityNamesAddress(cityNamesOffset()), + len - STR_FIELD_LEN)) { + return false; + } + } + update(temperature); + return true; + } + } + + /** Metadata for the collection of entries */ + static class EntryMeta { + + static int toIntFromUnsignedShort(short x) { + int ret = x; + if (ret < 0) { + ret += (1 << 16); + } + return ret; + } + + private final long baseAddress; + private final long cityNamesBaseAddress; // For city names that overflow Entry.STR_FIELD_LEN + private final int hashMask; + private final int n_entries; + private final int n_entriesBits; + private long cityNamesEndAddress; // 
[cityNamesBaseAddress, cityNamesEndAddress) + + EntryMeta(int n_entriesBits, EntryMeta oldEntryMeta) { + this.n_entries = 1 << n_entriesBits; + this.hashMask = (1 << n_entriesBits) - 1; + this.n_entriesBits = n_entriesBits; + this.baseAddress = Unsafely.allocateZeroedCacheLineAligned(this.n_entries << EntryData.ENTRY_SIZE_BITS); + if (oldEntryMeta == null) { + this.cityNamesBaseAddress = Unsafely.allocateZeroedCacheLineAligned(1 << 17); + this.cityNamesEndAddress = cityNamesBaseAddress; + } + else { + this.cityNamesBaseAddress = oldEntryMeta.cityNamesBaseAddress; + this.cityNamesEndAddress = oldEntryMeta.cityNamesEndAddress; + } + } + + public long cityNamesAddress(int extraLenOffset) { + return cityNamesBaseAddress + extraLenOffset; + } + + public int indexFromHash16(short hash16) { + return indexFromHash32(toIntFromUnsignedShort(hash16)); + } + + public int nEntriesBits() { + return n_entriesBits; + } + + // Base Address of nth entry + long baseAddress(int n) { + return baseAddress + ((long) n << EntryData.ENTRY_SIZE_BITS); + } + + // Size of each entry + int entrySizeInBytes() { + return 1 << EntryData.ENTRY_SIZE_BITS; + } + + int getAndIncrementCityNames(int len) { + long ret = cityNamesEndAddress; + cityNamesEndAddress += ((len + 7) >> 3) << 3; // use aligned 8 bytes + return (int) (ret - cityNamesBaseAddress); + } + + // Index of an entry with given hash32 + int indexFromHash32(int hash32) { + return hash32 & hashMask; + } + + // Number of entries + int nEntries() { + return n_entries; + } + + int nextIndex(int index) { + return (index + 1) & hashMask; + } + } + + static class Hashtable { + + // State + int n_filledEntries; + // A single Entry to avoid local allocation + private EntryData entry; + private EntryMeta entryMeta; + // Invariants + // hash16 = (short) hash32 + // index = hash16 & hashMask + private int hashHits = 0, hashMisses = 0; + + Hashtable(int slotsBits) { + entryMeta = new EntryMeta(slotsBits, null); + this.entry = new EntryData(entryMeta); + } + + public void addDataPoint(long srcAddr, int len, int hash32, short temperature) { + // hashHits++; + for (int index = entryMeta.indexFromHash32(hash32);; index = entryMeta.nextIndex(index)) { + entry.setIndex(index); + + if (!entry.isPresent()) { + entry.init(srcAddr, len, (short) hash32, temperature); + onNewEntry(); + return; + } + + if (entry.updateOnMatch(entryMeta, srcAddr, len, (short) hash32, temperature)) { + return; + } + // hashMisses++; + } + } + + public AggregateResult result() { + Map map = new LinkedHashMap<>(5_000); + for (int i = 0; i < entryMeta.nEntries(); i++) { + entry.setIndex(i); + if (entry.isPresent()) { + map.put(entry.cityNameString(), entry.stat()); + } + } + System.err.println( + STR.""" + HashHits = \{hashHits} + HashMisses = \{hashMisses} (\{hashMisses * 100.0 / hashHits}) + """); + return new AggregateResult(map); + } + + private EntryData getNewEntry(EntryData oldEntry, EntryMeta newEntryMeta) { + EntryData newEntry = new EntryData(newEntryMeta); + for (int index = newEntryMeta.indexFromHash16(oldEntry.hash16());; index = newEntryMeta.nextIndex(index)) { + newEntry.setIndex(index); + if (!newEntry.isPresent()) { + return newEntry; + } + } + } + + private void onNewEntry() { + if (++n_filledEntries == 450) { + reHash(16); + } + } + + private void reHash(int new_N_EntriesBits) { + EntryMeta oldEntryMeta = this.entryMeta; + EntryData oldEntry = new EntryData(oldEntryMeta); + Checks.checkArg(new_N_EntriesBits <= 16); + Checks.checkArg(new_N_EntriesBits > oldEntryMeta.nEntriesBits()); + 
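// Allocate the bigger table and re-insert every occupied entry at its new slot.
+            // The city-names arena is shared with the old table, so entries copy over verbatim.
+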
EntryMeta newEntryMeta = new EntryMeta(new_N_EntriesBits, oldEntryMeta); + for (int i = 0; i < oldEntryMeta.nEntries(); i++) { + oldEntry.setIndex(i); + if (oldEntry.isPresent()) { + Unsafely.copyMemory( + oldEntry.baseAddress(), + getNewEntry(oldEntry, newEntryMeta).baseAddress(), + oldEntryMeta.entrySizeInBytes()); + } + } + this.entryMeta = newEntryMeta; + this.entry = new EntryData(this.entryMeta); + } + } + + public interface LazyShardQueue { + + void close(String closerId, int shardIdx); + + Optional fileTailEndWork(int idx); + + ByteRange take(int shardIdx); + } + + static final class Reflection { + + static Method findMethodNamed(Object object, String name, Class... paramTypes) { + try { + return object.getClass().getMethod(name, paramTypes); + } + catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + static Object invoke(Object receiver, Method method, Object... params) { + try { + return method.invoke(receiver, params); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + public static class Runner { + + private final double commonChunkFraction; + private final int commonChunkSizeBits; + private final boolean fakeAdvance; + private final int hashtableSizeBits; + private final Path inputFile; + private final int minReservedBytesAtFileTail; + private final double munmapFraction; + private final int nThreads; + private final int shardSizeBits; + + public Runner( + Path inputFile, + int nThreads, + int chunkSizeBits, + double commonChunkFraction, + int commonChunkSizeBits, + int hashtableSizeBits, + int minReservedBytesAtFileTail, + double munmapFraction, + boolean fakeAdvance) { + this.inputFile = inputFile; + this.nThreads = nThreads; + this.shardSizeBits = chunkSizeBits; + this.commonChunkFraction = commonChunkFraction; + this.commonChunkSizeBits = commonChunkSizeBits; + this.hashtableSizeBits = hashtableSizeBits; + this.minReservedBytesAtFileTail = minReservedBytesAtFileTail; + this.munmapFraction = munmapFraction; + this.fakeAdvance = fakeAdvance; + } + + AggregateResult getSummaryStatistics() throws Exception { + int nThreads = this.nThreads < 0 ? Runtime.getRuntime().availableProcessors() : this.nThreads; + + LazyShardQueue shardQueue = new SerialLazyShardQueue( + 1L << shardSizeBits, + inputFile, + nThreads, + commonChunkFraction, + commonChunkSizeBits, + minReservedBytesAtFileTail, + munmapFraction, + fakeAdvance); + + ExecutorService executorService = Executors.newFixedThreadPool( + nThreads, + runnable -> { + Thread thread = new Thread(runnable); + thread.setDaemon(true); + return thread; + }); + + List> results = new ArrayList<>(); + for (int i = 0; i < nThreads; i++) { + final int shardIdx = i; + final Callable callable = () -> { + Tracing.recordWorkStart("Shard", shardIdx); + AggregateResult result = new ShardProcessor(shardQueue, hashtableSizeBits, shardIdx).processShard(); + Tracing.recordWorkEnd("Shard", shardIdx); + return result; + }; + results.add(executorService.submit(callable)); + } + Tracing.recordEvent("Basic push time"); + + // This particular sequence of Futures is so that both merge and munmap() can work as shards + // finish their computation without blocking on the entire set of shards to complete. In + // particular, munmap() doesn't need to wait on merge. + // First, submit a task to merge the results and then submit a task to cleanup bytebuffers + // from completed shards. 
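+            // (merge() below polls isDone() in a busy loop rather than blocking on get(),
+            // so shard results are folded in completion order.)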
+ Future resultFutures = executorService.submit(() -> merge(results)); + // Note that munmap() is serial and not parallel and hence we use just one thread. + executorService.submit(() -> closeByteBuffers(results, shardQueue)); + + AggregateResult result = resultFutures.get(); + Tracing.recordEvent("Merge results received"); + + Tracing.recordEvent("About to shutdown executor and wait"); + executorService.shutdown(); + executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS); + Tracing.recordEvent("Executor terminated"); + + Tracing.analyzeWorkThreads(nThreads); + return result; + } + + private void closeByteBuffers( + List> results, LazyShardQueue shardQueue) { + int n = results.size(); + boolean[] isDone = new boolean[n]; + int remaining = results.size(); + while (remaining > 0) { + for (int i = 0; i < n; i++) { + if (!isDone[i] && results.get(i).isDone()) { + remaining--; + isDone[i] = true; + shardQueue.close("Ending Cleaner", i); + } + } + } + } + + private AggregateResult merge(List> results) + throws ExecutionException, InterruptedException { + Tracing.recordEvent("Merge start time"); + Map output = null; + boolean[] isDone = new boolean[results.size()]; + int remaining = results.size(); + // Let's be naughty and spin in a busy loop + while (remaining > 0) { + for (int i = 0; i < results.size(); i++) { + if (!isDone[i] && results.get(i).isDone()) { + isDone[i] = true; + remaining--; + if (output == null) { + output = new TreeMap<>(results.get(i).get().tempStats()); + } + else { + for (Entry entry : results.get(i).get().tempStats().entrySet()) { + output.compute( + entry.getKey(), + (key, value) -> value == null ? entry.getValue() : Stat.merge(value, entry.getValue())); + } + } + } + } + } + Tracing.recordEvent("Merge end time"); + return new AggregateResult(output); + } + } + + public static class SerialLazyShardQueue implements LazyShardQueue { + + private static long roundToNearestLowerMultipleOf(long divisor, long value) { + return value / divisor * divisor; + } + + private final ByteRange[] byteRanges; + private final long chunkSize; + private final long commonChunkSize; + private final AtomicLong commonPool; + private final long effectiveFileSize; + private final boolean fakeAdvance; + private final long fileSize; + private final long[] perThreadData; + private final RandomAccessFile raf; + private final SeqLock seqLock; + + public SerialLazyShardQueue( + long chunkSize, + Path filePath, + int shards, + double commonChunkFraction, + int commonChunkSizeBits, + int fileTailReservedBytes, + double munmapFraction, + boolean fakeAdvance) + throws IOException { + this.fakeAdvance = fakeAdvance; + Checks.checkArg(commonChunkFraction < 0.9 && commonChunkFraction >= 0); + Checks.checkArg(fileTailReservedBytes >= 0); + this.raf = new RandomAccessFile(filePath.toFile(), "r"); + this.fileSize = raf.length(); + fileTailReservedBytes = fileTailReservedBytes == 0 + ? 
0 + : consumeToPreviousNewLineExclusive(raf, fileTailReservedBytes); + this.effectiveFileSize = fileSize - fileTailReservedBytes; + + // Common pool + long commonPoolStart = Math.min( + roundToNearestLowerMultipleOf( + chunkSize, (long) (effectiveFileSize * (1 - commonChunkFraction))), + effectiveFileSize); + this.commonPool = new AtomicLong(commonPoolStart); + this.commonChunkSize = 1L << commonChunkSizeBits; + + // Distribute chunks to shards + this.perThreadData = new long[shards << 4]; // thread idx -> 16*idx to avoid cache line conflict + for (long i = 0, + currentStart = 0, + remainingChunks = (commonPoolStart + chunkSize - 1) / chunkSize; i < shards; i++) { + long remainingShards = shards - i; + long currentChunks = (remainingChunks + remainingShards - 1) / remainingShards; + // Shard i handles: [currentStart, currentStart + currentChunks * chunkSize) + int pos = (int) i << 4; + perThreadData[pos] = currentStart; // next chunk begin + perThreadData[pos + 1] = currentStart + currentChunks * chunkSize; // shard end + perThreadData[pos + 2] = currentChunks; // active chunks remaining + // threshold below which need to shrink + // 0.03 is a practical number but the optimal strategy is this: + // Shard number N (1-based) should unmap as soon as it completes (R/(R+1))^N fraction of + // its work, where R = relative speed of unmap compared to the computation. + // For our problem, R ~ 75 because unmap unmaps 30GB/sec (but, it is serial) while + // cores go through data at the rate of 400MB/sec. + perThreadData[pos + 3] = (long) (currentChunks * (munmapFraction * (shards - i))); + perThreadData[pos + 4] = 1; // true iff munmap() hasn't been triggered yet + currentStart += currentChunks * chunkSize; + remainingChunks -= currentChunks; + } + this.chunkSize = chunkSize; + + this.byteRanges = new ByteRange[shards << 4]; + for (int i = 0; i < shards; i++) { + byteRanges[i << 4] = new ByteRange(raf, effectiveFileSize, i); + } + + this.seqLock = new SeqLock(); + } + + @Override + public void close(String closerId, int shardIdx) { + byteRanges[shardIdx << 4].close(closerId); + } + + @Override + public Optional fileTailEndWork(int idx) { + if (idx == 0 && effectiveFileSize < fileSize) { + ByteRange chunk = new ByteRange(raf, fileSize, 0); + chunk.setRange( + effectiveFileSize == 0 ? 0 : effectiveFileSize - 1 /* will consume newline at eFS-1 */, + fileSize); + return Optional.of(chunk); + } + return Optional.empty(); + } + + @Override + public ByteRange take(int shardIdx) { + // Try for thread local range + final int pos = shardIdx << 4; + final long rangeStart; + final long rangeEnd; + + if (perThreadData[pos + 2] >= 1) { + rangeStart = perThreadData[pos]; + rangeEnd = rangeStart + chunkSize; + // Don't do this in the if-check; it causes negative values that trigger intermediate + // cleanup + perThreadData[pos + 2]--; + if (!fakeAdvance) { + perThreadData[pos] = rangeEnd; + } + } + else { + rangeStart = commonPool.getAndAdd(commonChunkSize); + // If that's exhausted too, nothing remains! 
+ if (rangeStart >= effectiveFileSize) { + return null; + } + rangeEnd = rangeStart + commonChunkSize; + } + + if (perThreadData[pos + 2] < perThreadData[pos + 3] && perThreadData[pos + 4] > 0) { + if (attemptIntermediateClose(shardIdx)) { + perThreadData[pos + 4]--; + } + } + + ByteRange chunk = byteRanges[pos]; + chunk.setRange(rangeStart, rangeEnd); + return chunk; + } + + private boolean attemptIntermediateClose(int shardIdx) { + if (seqLock.acquire()) { + close("Intermediate Cleaner", shardIdx); + seqLock.release(); + return true; + } + return false; + } + + private int consumeToPreviousNewLineExclusive(RandomAccessFile raf, int minReservedBytes) { + try { + long pos = Math.max(raf.length() - minReservedBytes - 1, -1); + if (pos < 0) { + return (int) raf.length(); + } + + long start = Math.max(pos - 512, 0); + ByteBuffer buf = raf.getChannel().map(MapMode.READ_ONLY, start, pos + 1 - start); + while (pos >= 0 && buf.get((int) (pos - start)) != '\n') { + pos--; + } + pos++; + return (int) (raf.length() - pos); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + /** A low-traffic non-blocking lock. */ + static class SeqLock { + + private final AtomicBoolean isOccupied = new AtomicBoolean(false); + + boolean acquire() { + return !isOccupied.get() && isOccupied.compareAndSet(false, true); + } + + void release() { + isOccupied.set(false); + } + } + + public static class ShardProcessor { + + private final int shardIdx; + private final LazyShardQueue shardQueue; + private final FastShardProcessorState state; + + public ShardProcessor(LazyShardQueue shardQueue, int hashtableSizeBits, int shardIdx) { + this.shardQueue = shardQueue; + this.shardIdx = shardIdx; + this.state = new FastShardProcessorState(hashtableSizeBits); + } + + public AggregateResult processShard() { + return processShardReal(); + } + + private void processRange(ByteRange range) { + long nextPos = range.startAddress; + while (nextPos < range.endAddress) { + nextPos = state.processLine(nextPos); + } + } + + private void processRangeSlow(ByteRange range) { + long nextPos = range.startAddress; + while (nextPos < range.endAddress) { + nextPos = state.processLineSlow(nextPos); + } + } + + private AggregateResult processShardReal() { + // First process the file tail work to give ourselves freedom to go past ranges in parsing + shardQueue.fileTailEndWork(shardIdx).ifPresent(this::processRangeSlow); + + ByteRange range; + while ((range = shardQueue.take(shardIdx)) != null) { + processRange(range); + } + return result(); + } + + private AggregateResult result() { + return state.result(); + } + } + + public static class FastShardProcessorState { + + private static final long LEADING_ONE_BIT_MASK = 0x8080808080808080L; + private static final long ONE_MASK = 0x0101010101010101L; + private static final long SEMICOLON_MASK = 0x3b3b3b3b3b3b3b3bL; + private final Hashtable hashtable; + private final Map slowProcessStats = new HashMap<>(); + + public FastShardProcessorState(int slotsBits) { + this.hashtable = new Hashtable(slotsBits); + Checks.checkArg(slotsBits <= 16); // since this.hashes is 'short' + } + + public long processLine(long nextPos) { + final long origPos = nextPos; + + // Trying to extract this into a function made it slower.. so, leaving it at inlining. 
+ // It's a pity since the extracted version was more elegant to read + long firstLong; + int hash = 0; + // Don't run Long.numberOfTrailingZeros in hasSemiColon; it is not needed to establish + // whether there's a semicolon; only needed for pin-pointing length of the tail. + long s = hasSemicolon(firstLong = Unsafely.readLong(nextPos)); + final int trailingZeroes; + if (s == 0) { + hash = doHash(firstLong); + do { + nextPos += 8; + s = hasSemicolon(Unsafely.readLong(nextPos)); + } while (s == 0); + trailingZeroes = Long.numberOfTrailingZeros(s) + 1; // 8, 16, 24, .. # past ; + } + else { + trailingZeroes = Long.numberOfTrailingZeros(s) + 1; // 8, 16, 24, .. # past ; + hash = doHash(firstLong & maskOf(trailingZeroes - 8)); + } + // Sometimes we do mix a tail of length 0.. + nextPos += (trailingZeroes >> 3); + + final int temp = readTemperature(nextPos); + hashtable.addDataPoint(origPos, (int) (nextPos - 1 - origPos), hash, (short) (temp >> 3)); + return nextPos + (temp & 7); + } + + /** + * A slow version which is used only for the tail part of the file. Maintaining hashcode sync + * between this and the fast version is a pain for experimentation. So, we'll simply use a naive + * approach. + */ + public long processLineSlow(long nextPos) { + byte nextByte; + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + while ((nextByte = Unsafely.readByte(nextPos++)) != ';') { + baos.write(nextByte); + } + + int temperature = 0; + boolean negative = Unsafely.readByte(nextPos) == '-'; + while ((nextByte = Unsafely.readByte(nextPos++)) != '\n') { + if (nextByte != '-' && nextByte != '.') { + temperature = temperature * 10 + (nextByte - '0'); + } + } + if (negative) { + temperature = -temperature; + } + + updateStat(slowProcessStats, baos.toString(), Stat.firstReading(temperature)); + return nextPos; + } + + public AggregateResult result() { + AggregateResult result = hashtable.result(); + if (!slowProcessStats.isEmpty()) { + // bah.. just mutate the arg of the record... + for (Entry entry : slowProcessStats.entrySet()) { + updateStat(result.tempStats(), entry.getKey(), entry.getValue()); + } + } + return result; + } + + int readTemperature(long nextPos) { + // This Dependency chain + // read -> shift -> xor -> compare -> 2 in parallel [ shift -> read ] -> add -> shift + // Chain latency: 2 reads + 2 add + 4 logical [assuming compare = add] + // vs + // Prior Dependency chain (slightly optimized by hand) + // read -> compare to '-' -> read -> compare to '.' -> 3 in parallel [read -> imul] -> add + // Chain latency: 3 reads + 3 add + 1 mul [assuming compare = add] + long data = Unsafely.readLong(nextPos); + long d = data ^ (data >> 4); + if ((data & 0xFF) == '-') { + return TemperatureLookup.firstNeg(d >> 8) + TemperatureLookup.secondNeg(d >> 24); + } + else { + return TemperatureLookup.firstPos(d >> 0) + TemperatureLookup.secondPos(d >> 16); + } + } + + private int doHash(long value) { + long hash = 31L * (int) value + (int) (value >> 32); + return (int) (hash ^ (hash >> 17) ^ (hash >> 28)); + } + + private long hasSemicolon(long x) { + long a = (x ^ SEMICOLON_MASK); + return (a - ONE_MASK) & (~a) & LEADING_ONE_BIT_MASK; + } + + private long maskOf(int bits) { + return ~(-1L << bits); + } + + private void updateStat(Map map, String key, Stat curValue) { + map.compute(key, (_, value) -> value == null ? curValue : Stat.merge(value, curValue)); + } + } + + /** Represents aggregate stats. 
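Temperatures are fixed-point: integer tenths of a degree, so 12.3 is stored as 123.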
*/
+    public static class Stat {
+
+        public static Stat firstReading(int temp) {
+            return new Stat(temp, temp, temp, 1);
+        }
+
+        public static Stat merge(Stat left, Stat right) {
+            return new Stat(
+                    Math.min(left.min, right.min),
+                    Math.max(left.max, right.max),
+                    left.sum + right.sum,
+                    left.count + right.count);
+        }
+
+        long count, sum;
+        int min, max;
+
+        public Stat(int min, int max, long sum, long count) {
+            this.min = min;
+            this.max = max;
+            this.sum = sum;
+            this.count = count;
+        }
+
+        // Caution: mutates this Stat.
+        public Stat mergeReading(int curTemp) {
+            // Can this be improved further?
+            // Assuming random values for curTemp,
+            // min (& max) gets updated roughly a log(N)/N fraction of the time (a small number).
+            // In the worst case, there will be at most one branch misprediction.
+            if (curTemp > min) { // Mostly passes. On branch misprediction, just update min.
+                if (curTemp > max) { // Mostly fails. On branch misprediction, just update max.
+                    max = curTemp;
+                }
+            }
+            else {
+                min = curTemp;
+            }
+            sum += curTemp;
+            count++;
+            return this;
+        }
+
+        @Override
+        public String toString() {
+            return "%.1f/%.1f/%.1f".formatted(min / 10.0, sum / 10.0 / count, max / 10.0);
+        }
+    }
+
+    /**
+     * Lookup table for temperature parsing.
+     *
+     * <pre>
+     * char    binary
+     * 0       0011-0000
+     * 9       0011-1001
+     * .       0010-1110
+     * \n      0000-1010
+     *
+     * Notice that there's no overlap in the last 4 bits. This means that, given two successive
+     * bytes X, Y, each drawn from the above characters, we can REVERSIBLY hash them into a
+     * single byte: 8-bit-hash = (last 4 bits of X) concat (last 4 bits of Y).
+     *
+     * Such a hash requires a few more operations than ideal. A more practical hash is:
+     * (X<<4) ^ Y ^ (Y >> 4). This means that if you read 4 bytes after the '-',
+     * L = X Y Z W, where each of X Y Z W is a byte, then,
+     * L ^ (L >> 4) = D hash(X, Y) hash(Y, Z) hash(Z, W) where D = don't care. In other words, we
+     * can SWAR the hash.
+     *
+     * This has potential for minor conflicts; e.g., (3, NewLine) collides with (0, 9). But we
+     * don't have any collisions between two digits. That is, (x, y) will never collide with (a, b)
+     * where x, y, a, b are digits (proof left as an exercise, lol). Along with a couple of other
+     * such no-conflict observations, it suffices for our purposes.
+     *
+     * If we just precompute some values like
+     * - BigTable[hash(X,Y)] = 100*X + 10*Y
+     * - SmallTable[hash(Z,W)] = 10*Z + W
+     *
+     * where potentially X, Y, Z, W can be '.' or '\n' (with the arithmetic adjusted), we can look up
+     * the temperature pieces from BigTable and SmallTable and add them together (see the worked
+     * example below).
+     * </pre>
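+     *
+     * <p> A worked example to make the scheme concrete, reusing the BigTable/SmallTable names
+     * from above. For the input "12.3\n" we have X='1', Y='2', Z='.', W='3', and so:
+     * <pre>
+     * BigTable[hash('1','2')]   = 100*1 + 10*2 = 120
+     * SmallTable[hash('.','3')] = 3    (the '.' is absorbed by the adjusted arithmetic)
+     * temperature = 120 + 3 = 123 tenths, i.e. 12.3
+     * </pre>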
+     *
+     * <p>
This class is an implementation of the above idea. The lookup tables, being 256 ints
+     * long, will always be resident in L1 cache. What remains, then, is to also encode how much
+     * input is to be consumed, i.e. to count the '-' and newlines too. That can be piggybacked on
+     * top of the values.
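+     *
+     * <p> Concretely, each table entry packs (value << 3) | bytesToConsume. For "12.3\n" the two
+     * lookups sum to ((120 + 3) << 3) + 5, so the call site recovers the temperature as
+     * (sum >> 3) = 123 tenths and the input to skip as (sum & 7) = 5 bytes.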
+     *
+     * <p>
FWIW, this lookup appears to have reduced the temperature reading overhead substantially on + * a Ryzen 7950X machine. But, it wasn't done systematically; so, YMMV. + */ + public static class TemperatureLookup { + + // Second is the smaller (units place) + // First is the larger (100 & 10) + + // _NEG tables simply negate the value so that call-site can always simply add the values from + // the first and second units. Call-sites adding-up First and Second units adds up the + // amount of input to consume. + + // Here, 2 is the amount of bytes consumed. This informs how much the reading pointer + // should move. + // For pattern XY value = ((-100*X -10*Y) << 3) + 2 [2 = 1 for X, 1 for Y] + // For pattern Y. value = ((-10*Y) << 3) + 2 [2 = 1 for Y, 1 for .] + private static final int[] FIRST_NEG = make(true, true); + + // For pattern XY value = ((100*X + 10*Y) << 3) + 2 + // For pattern Y. value = ((10*Y) << 3) + 2 + private static final int[] FIRST_POS = make(true, false); + + // We count newline and any initial '-' as part of SECOND + // For pattern .Z value = (-Z << 3) + 2 + 2 [1 each for . and Z, 1 for newline, 1 for minus] + // For pattern Zn value = (-Z << 3) + 1 + 2 [1 for Z, 1 for newline, 1 for minus] + private static final int[] SECOND_NEG = make(false, true); + + // For pattern .Z value = (Z << 3) + 2 + 1 [1 each for . and Z, 1 for newline] + // For pattern Zn value = (Z << 3) + 1 + 1 [1 for Z, 1 for newline] + private static final int[] SECOND_POS = make(false, false); + + public static int firstNeg(long b) { + return FIRST_NEG[(int) (b & 255)]; + } + + public static int firstPos(long b) { + return FIRST_POS[(int) (b & 255)]; + } + + public static int secondNeg(long b) { + return SECOND_NEG[(int) (b & 255)]; + } + + public static int secondPos(long b) { + return SECOND_POS[(int) (b & 255)]; + } + + private static byte[] allDigits() { + byte[] out = new byte[10]; + for (byte a = '0'; a <= '9'; a++) { + out[a - '0'] = a; + } + return out; + } + + private static int hash(byte msb, byte lsb) { + // If K = [D msb lsb], then (K ^ (K>>4)) & 255 == hash(msb, lsb). D = don't care + return (msb << 4) ^ lsb ^ (lsb >> 4); + } + + private static int[] make(boolean isFirst, boolean isNegative) { + int[] ret = new int[256]; + boolean[] done = new boolean[256]; + + // Conventions: X = 100s place, Y = 10s place, Z = 1s place, n = new line + + // All the cases to handle + // X Y . Z + // Y . Z n + + // In little-endian order it becomes (byte-wise), shown in place value notation + // Z . Y X + // n Z . Y + // First = YX or .Y + // Second = Z. or nZ + + // Pattern 'YX' + for (byte x : allDigits()) { + for (byte y : allDigits()) { + int index = hash(y, x); + // Shouldn't occur in Second + int value = isFirst ? (y - '0') * 10 + (x - '0') * 100 : 12345; + int delta = isFirst ? 2 : 12345; + update(index, isNegative ? -value : value, delta, ret, done); + } + } + + // Pattern 'Z.' + for (byte z : allDigits()) { + int index = hash(z, (byte) '.'); + // shouldn't occur in First + int value = isFirst ? 12345 : (z - '0'); + int delta = isFirst ? 12345 : 2; + update(index, isNegative ? -value : value, delta, ret, done); + } + + // Pattern '.Y' + for (byte y : allDigits()) { + int index = hash((byte) '.', y); + // Shouldn't occur in Second + int value = isFirst ? 10 * (y - '0') : 12345; + int delta = isFirst ? 2 : 12345; + update(index, isNegative ? 
-value : value, delta, ret, done); + } + + // Pattern 'nZ' + for (byte z : allDigits()) { + int index = hash((byte) '\n', z); + // shouldn't occur in First + int value = isFirst ? 12345 : (z - '0'); + int delta = isFirst ? 12345 : 1; + update(index, isNegative ? -value : value, delta, ret, done); + } + + if (!isFirst) { + // Adjust the deltas to reflect how much input needs to be consumed + // need to consume the newline and any - sign in front + for (int i = 0; i < ret.length; i++) { + ret[i] += (isNegative ? 1 : 0) /* for - sign */ + 1 /* for new line */; + } + } + return ret; + } + + private static void update(int index, int value, int delta, int[] ret, boolean[] done) { + index &= 255; + Checks.checkArg(!done[index]); // just a sanity check that our hashing is indeed reversible + ret[index] = (value << 3) | delta; + done[index] = true; + } + } + + static class Tracing { + + private static final Map knownWorkThreadEvents; + private static long startTime; + + static { + // Maintain the ordering to be chronological in execution + // Map.of(..) screws up ordering + knownWorkThreadEvents = new LinkedHashMap<>(); + for (String id : List.of("Shard", "Intermediate Cleaner", "Ending Cleaner", "Buffer Creation")) { + knownWorkThreadEvents.put(id, new ThreadTimingsArray(id, 1 << 10)); + } + } + + static void analyzeWorkThreads(int nThreads) { + for (ThreadTimingsArray array : knownWorkThreadEvents.values()) { + errPrint(array.analyze(nThreads)); + } + } + + static void recordAppStart() { + startTime = System.nanoTime(); + printEvent("Start time", startTime); + } + + static void recordEvent(String event) { + printEvent(event, System.nanoTime()); + } + + static void recordWorkEnd(String id, int threadId) { + knownWorkThreadEvents.get(id).recordEnd(threadId); + } + + static void recordWorkStart(String id, int threadId) { + knownWorkThreadEvents.get(id).recordStart(threadId); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////// + + private static void errPrint(String message) { + System.err.println(message); + } + + private static void printEvent(String message, long nanoTime) { + errPrint(STR."\{message} = \{(nanoTime - startTime) / 1_000_000}ms"); + } + + public static class ThreadTimingsArray { + + private static String toString(long[] array) { + return Arrays.stream(array) + .map(x -> x < 0 ? 
-1 : x) + .mapToObj(x -> String.format("%6d", x)) + .collect(Collectors.joining(", ", "[ ", " ]")); + } + + private final String id; + private final long[] timestamps; + private boolean hasData = false; + + public ThreadTimingsArray(String id, int maxSize) { + this.timestamps = new long[maxSize]; + this.id = id; + } + + public String analyze(int nThreads) { + if (!hasData) { + return "%s has no thread timings data".formatted(id); + } + Checks.checkArg(nThreads <= timestamps.length); + long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE; + long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE; + long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE; + + long[] durationsMs = new long[nThreads]; + long[] completionsMs = new long[nThreads]; + long[] beginMs = new long[nThreads]; + for (int i = 0; i < nThreads; i++) { + long durationNs = timestamps[2 * i + 1] - timestamps[2 * i]; + durationsMs[i] = durationNs / 1_000_000; + completionsMs[i] = (timestamps[2 * i + 1] - startTime) / 1_000_000; + beginMs[i] = (timestamps[2 * i] - startTime) / 1_000_000; + + minDuration = Math.min(minDuration, durationNs); + maxDuration = Math.max(maxDuration, durationNs); + + minBegin = Math.min(minBegin, timestamps[2 * i] - startTime); + maxBegin = Math.max(maxBegin, timestamps[2 * i] - startTime); + + maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1] - startTime); + minCompletion = Math.min(minCompletion, timestamps[2 * i + 1] - startTime); + } + return STR.""" + ------------------------------------------------------------------------------------------- + \{id} Stats + ------------------------------------------------------------------------------------------- + Max duration = \{maxDuration / 1_000_000} ms + Min duration = \{minDuration / 1_000_000} ms + Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms [\{maxCompletion / 1_000_000} - \{minBegin / 1_000_000} ] + Completion Timespan[max(end)-min(end)] = \{(maxCompletion - minCompletion) / 1_000_000} ms + Begin Timespan[max(begin)-min(begin)] = \{(maxBegin - minBegin) / 1_000_000} ms + Average Duration = \{Arrays.stream(durationsMs) + .average() + .getAsDouble()} ms + Durations = \{toString(durationsMs)} ms + Begin Timestamps = \{toString(beginMs)} ms + Completion Timestamps = \{toString(completionsMs)} ms + """; + } + + public void recordEnd(int idx) { + timestamps[2 * idx + 1] = System.nanoTime(); + hasData = true; + } + + public void recordStart(int idx) { + timestamps[2 * idx] = System.nanoTime(); + hasData = true; + } + } + } + + static class Unsafely { + + private static final Unsafe unsafe = getUnsafe(); + + public static long allocateZeroedCacheLineAligned(int size) { + long address = unsafe.allocateMemory(size + 63); + unsafe.setMemory(address, size + 63, (byte) 0); + return (address + 63) & ~63; + } + + public static void copyMemory(long srcAddress, long destAddress, long byteCount) { + unsafe.copyMemory(srcAddress, destAddress, byteCount); + } + + public static boolean matches(long srcAddr, long destAddress, int len) { + if (len < 8) { + return (readLong(srcAddr) & ~(-1L << (len << 3))) == (readLong(destAddress) & ~(-1L << (len << 3))); + } + if (readLong(srcAddr) != readLong(destAddress)) { + return false; + } + len -= 8; + + if (len < 8) { + return (readLong(srcAddr + 8) & ~(-1L << (len << 3))) == (readLong(destAddress + 8) & ~(-1L << (len << 3))); + } + if (readLong(srcAddr + 8) != readLong(destAddress + 8)) { + return false; + } + len -= 8; + srcAddr += 16; + destAddress += 16; + + 
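// The first 16 bytes were compared above; walk the remainder in 8-byte strides,
+        // then mop up the tail with 4-, 2- and 1-byte compares.
+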
int idx = 0; + for (; idx < (len & ~7); idx += 8) { + if (Unsafely.readLong(srcAddr + idx) != Unsafely.readLong(destAddress + idx)) { + return false; + } + } + + if (idx < (len & ~3)) { + if (Unsafely.readInt(srcAddr + idx) != Unsafely.readInt(destAddress + idx)) { + return false; + } + idx += 4; + } + + if (idx < (len & ~1)) { + if (Unsafely.readShort(srcAddr + idx) != Unsafely.readShort(destAddress + idx)) { + return false; + } + idx += 2; + } + + return idx >= len || Unsafely.readByte(srcAddr + idx) == Unsafely.readByte(destAddress + idx); + } + + public static byte readByte(long address) { + return unsafe.getByte(address); + } + + public static int readInt(long address) { + return unsafe.getInt(address); + } + + public static long readLong(long address) { + return unsafe.getLong(address); + } + + public static short readShort(long address) { + return unsafe.getShort(address); + } + + public static void setByte(long address, byte len) { + unsafe.putByte(address, len); + } + + public static void setInt(long address, int value) { + unsafe.putInt(address, value); + } + + public static void setShort(long address, short len) { + unsafe.putShort(address, len); + } + + private static Unsafe getUnsafe() { + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + return (Unsafe) unsafeField.get(null); + } + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + } +}