/*
* Copyright 2023 The original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.morling.onebrc;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel.MapMode;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import sun.misc.Unsafe;
/**
 * Unlike its sister submission {@code CalculateAverage_vemana}, this submission employs
 * non-idiomatic techniques such as SWAR and Unsafe.
 *
 * <p>For details on how this solution works, check the documentation on the sister submission.
 */
public class CalculateAverage_vemanaNonIdiomatic {
public static void main(String[] args) throws Exception {
String className = MethodHandles.lookup().lookupClass().getSimpleName();
System.err.println(
STR."""
------------------------------------------------
Running \{className}
-------------------------------------------------
""");
Tracing.recordAppStart();
Runtime.getRuntime()
.addShutdownHook(
new Thread(
() -> {
Tracing.recordEvent("In Shutdown hook");
}));
// First process in large chunks without coordination among threads
// Use chunkSizeBits for the large-chunk size
int chunkSizeBits = 20;
// For the last commonChunkFraction fraction of total work, use smaller chunk sizes
double commonChunkFraction = 0.03;
// Use commonChunkSizeBits for the small-chunk size
int commonChunkSizeBits = 18;
// Size of the hashtable (attempt to fit in the 512 KB L2 cache of the eval machine)
int hashtableSizeBits = className.toLowerCase().contains("nonidiomatic") ? 13 : 16;
// Reserve some bytes at the file tail to give us freedom to read 8-byte longs past range ends
int minReservedBytesAtFileTail = 9;
// Number of threads
int nThreads = -1;
String inputFile = "measurements.txt";
// Parallelize unmap. Thread #n (n=1,2,..N) unmaps its bytebuffer when
// munmapFraction * n work remains.
double munmapFraction = 0.03;
boolean fakeAdvance = false;
for (String arg : args) {
String key = arg.substring(0, arg.indexOf('=')).trim();
String value = arg.substring(key.length() + 1).trim();
switch (key) {
case "chunkSizeBits":
chunkSizeBits = Integer.parseInt(value);
break;
case "commonChunkFraction":
commonChunkFraction = Double.parseDouble(value);
break;
case "commonChunkSizeBits":
commonChunkSizeBits = Integer.parseInt(value);
break;
case "hashtableSizeBits":
hashtableSizeBits = Integer.parseInt(value);
break;
case "inputFile":
inputFile = value;
break;
case "munmapFraction":
munmapFraction = Double.parseDouble(value);
break;
case "fakeAdvance":
fakeAdvance = Boolean.parseBoolean(value);
break;
case "nThreads":
nThreads = Integer.parseInt(value);
break;
default:
throw new IllegalArgumentException("Unknown argument: " + arg);
}
}
System.out.println(
new Runner(
Path.of(inputFile),
nThreads,
chunkSizeBits,
commonChunkFraction,
commonChunkSizeBits,
hashtableSizeBits,
minReservedBytesAtFileTail,
munmapFraction,
fakeAdvance)
.getSummaryStatistics());
Tracing.recordEvent("Final result printed");
}
public record AggregateResult(Map<String, Stat> tempStats) {
@Override
public String toString() {
return this.tempStats().entrySet().stream()
.sorted(Map.Entry.comparingByKey())
.map(entry -> "%s=%s".formatted(entry.getKey(), entry.getValue()))
.collect(Collectors.joining(", ", "{", "}"));
}
}
// Mutable to avoid allocation
public static class ByteRange {
private static final int BUF_SIZE = 1 << 28;
private final long fileSize;
private final long maxEndPos; // Treat as if the file ends here
private final RandomAccessFile raf;
private final int shardIdx;
private final List<MappedByteBuffer> unclosedBuffers = new ArrayList<>();
// ***************** What this is doing and why *****************
// Reading from a ByteBuffer appears faster than from a MemorySegment, but a ByteBuffer can
// only be Integer.MAX_VALUE long; creating one byteBuffer per chunk exhausts the native memory
// quota and the JVM crashes without further tuning parameters.
//
// So, in this solution, create a sliding window of byteBuffers:
// - Create a large byteBuffer that spans the chunk
// - If the next chunk falls outside the byteBuffer, create another byteBuffer that spans the
// chunk. Because chunks are allocated serially, a single large (BUF_SIZE = 1 << 28) byteBuffer
// spans many successive chunks.
// - In fact, for serial chunk allocation (which is friendly to page faulting anyway),
// the number of created ByteBuffers doesn't exceed [size of shard / BUF_SIZE], which is less
// than 100 per thread and comfortably below what the JVM can handle (65K) without further
// param tuning
// - This enables a (relatively) allocation-free chunking implementation. Our chunking impl uses
// fine-grained chunking for the last, say, X% of work to avoid being hostage to stragglers
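//
// Illustrative numbers (not from the original author): with BUF_SIZE = 1 << 28 (256 MB) and,
// say, a ~13 GB input split across 8 shards, each shard spans ~1.6 GB and therefore creates at
// most ~7 buffers over its lifetime; the bound scales as shardSize / BUF_SIZE.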
///////////// The PUBLIC API
public MappedByteBuffer byteBuffer;
public long endAddress; // the virtual memory address corresponding to 'endInBuf'
public int endInBuf; // where the chunk ends inside the buffer
public long startAddress; // the virtual memory address corresponding to 'startInBuf'
public int startInBuf; // where the chunk starts inside the buffer
///////////// Private State
long bufferBaseAddr; // buffer's base virtual memory address
long extentEnd; // byteBuffer's ending coordinate
long extentStart; // byteBuffer's begin coordinate
// Uninitialized; for mutability
public ByteRange(RandomAccessFile raf, long maxEndPos, int shardIdx) {
this.raf = raf;
this.maxEndPos = maxEndPos;
this.shardIdx = shardIdx;
try {
this.fileSize = raf.length();
}
catch (IOException e) {
throw new RuntimeException(e);
}
bufferCleanSlate();
}
public void close(String closerId) {
Tracing.recordWorkStart(closerId, shardIdx);
bufferCleanSlate();
for (MappedByteBuffer buf : unclosedBuffers) {
close(buf);
}
unclosedBuffers.clear();
Tracing.recordWorkEnd(closerId, shardIdx);
}
public void setRange(long rangeStart, long rangeEnd) {
if (rangeEnd + 1024 > extentEnd || rangeStart < extentStart) {
setByteBufferExtent(rangeStart, Math.min(rangeStart + BUF_SIZE, fileSize));
}
if (rangeStart > 0) {
rangeStart = 1 + nextNewLine(rangeStart);
}
else {
rangeStart = 0;
}
if (rangeEnd < maxEndPos) {
// rangeEnd = 1 + nextNewLine(rangeEnd); // not needed
rangeEnd = 1 + rangeEnd;
}
else {
rangeEnd = maxEndPos;
}
startInBuf = (int) (rangeStart - extentStart);
endInBuf = (int) (rangeEnd - extentStart);
startAddress = bufferBaseAddr + startInBuf;
endAddress = bufferBaseAddr + endInBuf;
}
@Override
public String toString() {
return STR."""
ByteRange {
shard = \{shardIdx}
extentStart = \{extentStart}
extentEnd = \{extentEnd}
startInBuf = \{startInBuf}
endInBuf = \{endInBuf}
startAddress = \{startAddress}
endAddress = \{endAddress}
}
""";
}
private void bufferCleanSlate() {
if (byteBuffer != null) {
unclosedBuffers.add(byteBuffer);
byteBuffer = null;
}
extentEnd = extentStart = bufferBaseAddr = startAddress = endAddress = -1;
}
private void close(MappedByteBuffer buffer) {
Method cleanerMethod = Reflection.findMethodNamed(buffer, "cleaner");
cleanerMethod.setAccessible(true);
Object cleaner = Reflection.invoke(buffer, cleanerMethod);
Method cleanMethod = Reflection.findMethodNamed(cleaner, "clean");
cleanMethod.setAccessible(true);
Reflection.invoke(cleaner, cleanMethod);
}
private long getBaseAddr(MappedByteBuffer buffer) {
Method addressMethod = Reflection.findMethodNamed(buffer, "address");
addressMethod.setAccessible(true);
return (long) Reflection.invoke(buffer, addressMethod);
}
private long nextNewLine(long pos) {
int nextPos = (int) (pos - extentStart);
while (byteBuffer.get(nextPos) != '\n') {
nextPos++;
}
return nextPos + extentStart;
}
/**
 * Extent is different from Range: Range is what needs to be processed; Extent is what the byte
 * buffer can read without failing.
 */
private void setByteBufferExtent(long start, long end) {
bufferCleanSlate();
try {
byteBuffer = raf.getChannel().map(MapMode.READ_ONLY, start, end - start);
byteBuffer.order(ByteOrder.nativeOrder());
}
catch (IOException e) {
throw new RuntimeException(e);
}
extentStart = start;
extentEnd = end;
bufferBaseAddr = getBaseAddr(byteBuffer);
}
}
public static final class Checks {
public static void checkArg(boolean condition) {
if (!condition) {
throw new IllegalArgumentException();
}
}
private Checks() {
}
}
/*
 * ENTRY SHAPE
 * Fields sit on their natural alignment boundaries: 4-byte fields on 4-byte boundaries,
 * 2-byte fields on 2-byte boundaries, etc.
 * 32 bytes per entry.
 * 96 KB L1 cache: 2048 entries (64 KB) should fully fit.
 * -------------------
 * str: 14 bytes [Defined by constant STR_FIELD_LEN]
 * hash: 2 bytes
 * cityNameOffset: 3 bytes // Index into city names array if len > STR_FIELD_LEN bytes
 * len: 1 byte // Length of string, in bytes
 * sum: 4 bytes
 * count: 4 bytes
 * max: 2 bytes
 * min: 2 bytes
 */
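// Resulting field offsets (per the constants below): str [0,14), hash [14,16),
// cityNameOffset [16,19), len [19,20), sum [20,24), count [24,28), max [28,30), min [30,32):
// 32 bytes total, matching ENTRY_SIZE_BITS = 5 (1 << 5 = 32).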
static class EntryData {
public static final int ENTRY_SIZE_BITS = 5;
/////////// OFFSETS ///////////////
private static final int OFFSET_STR = 0;
private static final int STR_FIELD_LEN = 14;
private static final int OFFSET_HASH = OFFSET_STR + STR_FIELD_LEN;
private static final int OFFSET_CITY_NAME_EXTRA = OFFSET_HASH + 2;
private static final int OFFSET_LEN = OFFSET_CITY_NAME_EXTRA + 3;
private static final int OFFSET_SUM = OFFSET_LEN + 1;
private static final int OFFSET_COUNT = OFFSET_SUM + 4;
private static final int OFFSET_MAX = OFFSET_COUNT + 4;
private static final int OFFSET_MIN = OFFSET_MAX + 2;
public static int strFieldLen() {
return STR_FIELD_LEN;
}
private final EntryMeta entryMeta;
private long baseAddress;
public EntryData(EntryMeta entryMeta) {
this.entryMeta = entryMeta;
}
public long baseAddress() {
return baseAddress;
}
public String cityNameString() {
int len = len();
byte[] zeBytes = new byte[len];
for (int i = 0; i < Math.min(len, strFieldLen()); i++) {
zeBytes[i] = Unsafely.readByte(baseAddress + i);
}
if (len > strFieldLen()) {
int rem = len - strFieldLen();
long ptr = entryMeta.cityNamesAddress(cityNamesOffset());
for (int i = 0; i < rem; i++) {
zeBytes[strFieldLen() + i] = Unsafely.readByte(ptr + i);
}
}
return new String(zeBytes);
}
public int cityNamesOffset() {
return Unsafely.readInt(baseAddress + OFFSET_CITY_NAME_EXTRA) & 0xFFFFFF;
}
public int count() {
return Unsafely.readInt(baseAddress + OFFSET_COUNT);
}
public short hash16() {
return Unsafely.readShort(baseAddress + OFFSET_HASH);
}
public int index() {
return (int) ((baseAddress() - entryMeta.baseAddress(0)) >> ENTRY_SIZE_BITS);
}
public void init(long srcAddr, int len, short hash16, short temperature) {
// Copy the string
Unsafely.copyMemory(srcAddr, strAddress(), Math.min(len, EntryData.strFieldLen()));
if (len > EntryData.strFieldLen()) {
int remaining = len - EntryData.strFieldLen();
int cityNamesOffset = entryMeta.getAndIncrementCityNames(remaining);
Unsafely.copyMemory(
srcAddr + EntryData.strFieldLen(),
entryMeta.cityNamesAddress(cityNamesOffset),
remaining);
setCityNameOffset(cityNamesOffset, len);
}
else {
setLen((byte) len);
}
// and then update the others
setHash16(hash16);
setSum(temperature);
setCount(1);
setMax(temperature);
setMin(temperature);
}
public boolean isPresent() {
return len() > 0;
}
public int len() {
return Unsafely.readByte(baseAddress + OFFSET_LEN);
}
public short max() {
return Unsafely.readShort(baseAddress + OFFSET_MAX);
}
public short min() {
return Unsafely.readShort(baseAddress + OFFSET_MIN);
}
public void setBaseAddress(long baseAddress) {
this.baseAddress = baseAddress;
}
public void setCityNameOffset(int cityNamesOffset, int len) {
// The shift by 24 packs the 1-byte len above the 3-byte city-name offset, writing both
// fields with a single 4-byte store; 24 is a bit shift, not a field offset.
Unsafely.setInt(baseAddress + OFFSET_CITY_NAME_EXTRA, cityNamesOffset | (len << 24));
}
public void setCount(int value) {
Unsafely.setInt(baseAddress + OFFSET_COUNT, value);
}
public void setHash16(short value) {
Unsafely.setShort(baseAddress + OFFSET_HASH, value);
}
public void setIndex(int index) {
setBaseAddress(entryMeta.baseAddress(index));
}
public void setLen(byte value) {
Unsafely.setByte(baseAddress + OFFSET_LEN, value);
}
public void setMax(short value) {
Unsafely.setShort(baseAddress + OFFSET_MAX, value);
}
public void setMin(short value) {
Unsafely.setShort(baseAddress + OFFSET_MIN, value);
}
public void setSum(int value) {
Unsafely.setInt(baseAddress + OFFSET_SUM, value);
}
public Stat stat() {
return new Stat(min(), max(), sum(), count());
}
public long strAddress() {
return baseAddress + OFFSET_STR;
}
public int sum() {
return Unsafely.readInt(baseAddress + OFFSET_SUM);
}
public String toString() {
return STR."""
min = \{min()}
max = \{max()}
count = \{count()}
sum = \{sum()}
""";
}
public void update(short temperature) {
setMin((short) Math.min(min(), temperature));
setMax((short) Math.max(max(), temperature));
setCount(count() + 1);
setSum(sum() + temperature);
}
public boolean updateOnMatch(
EntryMeta entryMeta, long srcAddr, int len, short hash16, short temperature) {
// Quick paths
if (len() != len) {
return false;
}
if (hash16() != hash16) {
return false;
}
// Actual string comparison
if (len <= STR_FIELD_LEN) {
if (!Unsafely.matches(srcAddr, strAddress(), len)) {
return false;
}
}
else {
if (!Unsafely.matches(srcAddr, strAddress(), STR_FIELD_LEN)) {
return false;
}
if (!Unsafely.matches(
srcAddr + STR_FIELD_LEN,
entryMeta.cityNamesAddress(cityNamesOffset()),
len - STR_FIELD_LEN)) {
return false;
}
}
update(temperature);
return true;
}
}
/** Metadata for the collection of entries */
static class EntryMeta {
static int toIntFromUnsignedShort(short x) {
int ret = x;
if (ret < 0) {
ret += (1 << 16);
}
return ret;
}
private final long baseAddress;
private final long cityNamesBaseAddress; // For city names that overflow Entry.STR_FIELD_LEN
private final int hashMask;
private final int n_entries;
private final int n_entriesBits;
private long cityNamesEndAddress; // [cityNamesBaseAddress, cityNamesEndAddress)
EntryMeta(int n_entriesBits, EntryMeta oldEntryMeta) {
this.n_entries = 1 << n_entriesBits;
this.hashMask = (1 << n_entriesBits) - 1;
this.n_entriesBits = n_entriesBits;
this.baseAddress = Unsafely.allocateZeroedCacheLineAligned(this.n_entries << EntryData.ENTRY_SIZE_BITS);
if (oldEntryMeta == null) {
this.cityNamesBaseAddress = Unsafely.allocateZeroedCacheLineAligned(1 << 17);
this.cityNamesEndAddress = cityNamesBaseAddress;
}
else {
this.cityNamesBaseAddress = oldEntryMeta.cityNamesBaseAddress;
this.cityNamesEndAddress = oldEntryMeta.cityNamesEndAddress;
}
}
public long cityNamesAddress(int extraLenOffset) {
return cityNamesBaseAddress + extraLenOffset;
}
public int indexFromHash16(short hash16) {
return indexFromHash32(toIntFromUnsignedShort(hash16));
}
public int nEntriesBits() {
return n_entriesBits;
}
// Base Address of nth entry
long baseAddress(int n) {
return baseAddress + ((long) n << EntryData.ENTRY_SIZE_BITS);
}
// Size of each entry
int entrySizeInBytes() {
return 1 << EntryData.ENTRY_SIZE_BITS;
}
int getAndIncrementCityNames(int len) {
long ret = cityNamesEndAddress;
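// Round len up to a multiple of 8, e.g. len = 5 reserves 8 bytes, len = 17 reserves 24.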
cityNamesEndAddress += ((len + 7) >> 3) << 3; // use aligned 8 bytes
return (int) (ret - cityNamesBaseAddress);
}
// Index of an entry with given hash32
int indexFromHash32(int hash32) {
return hash32 & hashMask;
}
// Number of entries
int nEntries() {
return n_entries;
}
int nextIndex(int index) {
return (index + 1) & hashMask;
}
}
static class Hashtable {
// State
int n_filledEntries;
// A single Entry to avoid local allocation
private EntryData entry;
private EntryMeta entryMeta;
// Invariants
// hash16 = (short) hash32
// index = hash16 & hashMask
private int hashHits = 0, hashMisses = 0;
Hashtable(int slotsBits) {
entryMeta = new EntryMeta(slotsBits, null);
this.entry = new EntryData(entryMeta);
}
public void addDataPoint(long srcAddr, int len, int hash32, short temperature) {
// hashHits++;
for (int index = entryMeta.indexFromHash32(hash32);; index = entryMeta.nextIndex(index)) {
entry.setIndex(index);
if (!entry.isPresent()) {
entry.init(srcAddr, len, (short) hash32, temperature);
onNewEntry();
return;
}
if (entry.updateOnMatch(entryMeta, srcAddr, len, (short) hash32, temperature)) {
return;
}
// hashMisses++;
}
}
public AggregateResult result() {
Map<String, Stat> map = new LinkedHashMap<>(5_000);
for (int i = 0; i < entryMeta.nEntries(); i++) {
entry.setIndex(i);
if (entry.isPresent()) {
map.put(entry.cityNameString(), entry.stat());
}
}
System.err.println(
STR."""
HashHits = \{hashHits}
HashMisses = \{hashMisses} (\{hashMisses * 100.0 / hashHits})
""");
return new AggregateResult(map);
}
private EntryData getNewEntry(EntryData oldEntry, EntryMeta newEntryMeta) {
EntryData newEntry = new EntryData(newEntryMeta);
for (int index = newEntryMeta.indexFromHash16(oldEntry.hash16());; index = newEntryMeta.nextIndex(index)) {
newEntry.setIndex(index);
if (!newEntry.isPresent()) {
return newEntry;
}
}
}
private void onNewEntry() {
if (++n_filledEntries == 450) {
reHash(16);
}
}
private void reHash(int new_N_EntriesBits) {
EntryMeta oldEntryMeta = this.entryMeta;
EntryData oldEntry = new EntryData(oldEntryMeta);
Checks.checkArg(new_N_EntriesBits <= 16);
Checks.checkArg(new_N_EntriesBits > oldEntryMeta.nEntriesBits());
EntryMeta newEntryMeta = new EntryMeta(new_N_EntriesBits, oldEntryMeta);
for (int i = 0; i < oldEntryMeta.nEntries(); i++) {
oldEntry.setIndex(i);
if (oldEntry.isPresent()) {
Unsafely.copyMemory(
oldEntry.baseAddress(),
getNewEntry(oldEntry, newEntryMeta).baseAddress(),
oldEntryMeta.entrySizeInBytes());
}
}
this.entryMeta = newEntryMeta;
this.entry = new EntryData(this.entryMeta);
}
}
public interface LazyShardQueue {
void close(String closerId, int shardIdx);
Optional<ByteRange> fileTailEndWork(int idx);
ByteRange take(int shardIdx);
}
static final class Reflection {
static Method findMethodNamed(Object object, String name, Class<?>... paramTypes) {
try {
return object.getClass().getMethod(name, paramTypes);
}
catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
}
static Object invoke(Object receiver, Method method, Object... params) {
try {
return method.invoke(receiver, params);
}
catch (Exception e) {
throw new RuntimeException(e);
}
}
}
public static class Runner {
private final double commonChunkFraction;
private final int commonChunkSizeBits;
private final boolean fakeAdvance;
private final int hashtableSizeBits;
private final Path inputFile;
private final int minReservedBytesAtFileTail;
private final double munmapFraction;
private final int nThreads;
private final int shardSizeBits;
public Runner(
Path inputFile,
int nThreads,
int chunkSizeBits,
double commonChunkFraction,
int commonChunkSizeBits,
int hashtableSizeBits,
int minReservedBytesAtFileTail,
double munmapFraction,
boolean fakeAdvance) {
this.inputFile = inputFile;
this.nThreads = nThreads;
this.shardSizeBits = chunkSizeBits;
this.commonChunkFraction = commonChunkFraction;
this.commonChunkSizeBits = commonChunkSizeBits;
this.hashtableSizeBits = hashtableSizeBits;
this.minReservedBytesAtFileTail = minReservedBytesAtFileTail;
this.munmapFraction = munmapFraction;
this.fakeAdvance = fakeAdvance;
}
AggregateResult getSummaryStatistics() throws Exception {
int nThreads = this.nThreads < 0 ? Runtime.getRuntime().availableProcessors() : this.nThreads;
LazyShardQueue shardQueue = new SerialLazyShardQueue(
1L << shardSizeBits,
inputFile,
nThreads,
commonChunkFraction,
commonChunkSizeBits,
minReservedBytesAtFileTail,
munmapFraction,
fakeAdvance);
ExecutorService executorService = Executors.newFixedThreadPool(
nThreads,
runnable -> {
Thread thread = new Thread(runnable);
thread.setDaemon(true);
return thread;
});
List<Future<AggregateResult>> results = new ArrayList<>();
for (int i = 0; i < nThreads; i++) {
final int shardIdx = i;
final Callable<AggregateResult> callable = () -> {
Tracing.recordWorkStart("Shard", shardIdx);
AggregateResult result = new ShardProcessor(shardQueue, hashtableSizeBits, shardIdx).processShard();
Tracing.recordWorkEnd("Shard", shardIdx);
return result;
};
results.add(executorService.submit(callable));
}
Tracing.recordEvent("Basic push time");
// This particular sequence of Futures is so that both merge and munmap() can work as shards
// finish their computation without blocking on the entire set of shards to complete. In
// particular, munmap() doesn't need to wait on merge.
// First, submit a task to merge the results and then submit a task to cleanup bytebuffers
// from completed shards.
Future<AggregateResult> resultFutures = executorService.submit(() -> merge(results));
// Note that munmap() is serial and not parallel and hence we use just one thread.
executorService.submit(() -> closeByteBuffers(results, shardQueue));
AggregateResult result = resultFutures.get();
Tracing.recordEvent("Merge results received");
Tracing.recordEvent("About to shutdown executor and wait");
executorService.shutdown();
executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS);
Tracing.recordEvent("Executor terminated");
Tracing.analyzeWorkThreads(nThreads);
return result;
}
private void closeByteBuffers(
List<Future<AggregateResult>> results, LazyShardQueue shardQueue) {
int n = results.size();
boolean[] isDone = new boolean[n];
int remaining = results.size();
while (remaining > 0) {
for (int i = 0; i < n; i++) {
if (!isDone[i] && results.get(i).isDone()) {
remaining--;
isDone[i] = true;
shardQueue.close("Ending Cleaner", i);
}
}
}
}
private AggregateResult merge(List<Future<AggregateResult>> results)
throws ExecutionException, InterruptedException {
Tracing.recordEvent("Merge start time");
Map<String, Stat> output = null;
boolean[] isDone = new boolean[results.size()];
int remaining = results.size();
// Let's be naughty and spin in a busy loop
while (remaining > 0) {
for (int i = 0; i < results.size(); i++) {
if (!isDone[i] && results.get(i).isDone()) {
isDone[i] = true;
remaining--;
if (output == null) {
output = new TreeMap<>(results.get(i).get().tempStats());
}
else {
for (Entry<String, Stat> entry : results.get(i).get().tempStats().entrySet()) {
output.compute(
entry.getKey(),
(key, value) -> value == null ? entry.getValue() : Stat.merge(value, entry.getValue()));
}
}
}
}
}
Tracing.recordEvent("Merge end time");
return new AggregateResult(output);
}
}
public static class SerialLazyShardQueue implements LazyShardQueue {
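// Rounds value down to the nearest multiple of divisor,
// e.g. divisor = 1 << 20, value = 2_500_000 -> 2_097_152.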
private static long roundToNearestLowerMultipleOf(long divisor, long value) {
return value / divisor * divisor;
}
private final ByteRange[] byteRanges;
private final long chunkSize;
private final long commonChunkSize;
private final AtomicLong commonPool;
private final long effectiveFileSize;
private final boolean fakeAdvance;
private final long fileSize;
private final long[] perThreadData;
private final RandomAccessFile raf;
private final SeqLock seqLock;
public SerialLazyShardQueue(
long chunkSize,
Path filePath,
int shards,
double commonChunkFraction,
int commonChunkSizeBits,
int fileTailReservedBytes,
double munmapFraction,
boolean fakeAdvance)
throws IOException {
this.fakeAdvance = fakeAdvance;
Checks.checkArg(commonChunkFraction < 0.9 && commonChunkFraction >= 0);
Checks.checkArg(fileTailReservedBytes >= 0);
this.raf = new RandomAccessFile(filePath.toFile(), "r");
this.fileSize = raf.length();
fileTailReservedBytes = fileTailReservedBytes == 0
? 0
: consumeToPreviousNewLineExclusive(raf, fileTailReservedBytes);
this.effectiveFileSize = fileSize - fileTailReservedBytes;
// Common pool
long commonPoolStart = Math.min(
roundToNearestLowerMultipleOf(
chunkSize, (long) (effectiveFileSize * (1 - commonChunkFraction))),
effectiveFileSize);
this.commonPool = new AtomicLong(commonPoolStart);
this.commonChunkSize = 1L << commonChunkSizeBits;
// Distribute chunks to shards
this.perThreadData = new long[shards << 4]; // thread idx -> 16*idx to avoid cache line conflict
for (long i = 0,
currentStart = 0,
remainingChunks = (commonPoolStart + chunkSize - 1) / chunkSize; i < shards; i++) {
long remainingShards = shards - i;
long currentChunks = (remainingChunks + remainingShards - 1) / remainingShards;
// Shard i handles: [currentStart, currentStart + currentChunks * chunkSize)
int pos = (int) i << 4;
perThreadData[pos] = currentStart; // next chunk begin
perThreadData[pos + 1] = currentStart + currentChunks * chunkSize; // shard end
perThreadData[pos + 2] = currentChunks; // active chunks remaining
// threshold below which need to shrink
// 0.03 is a practical number but the optimal strategy is this:
// Shard number N (1-based) should unmap as soon as it completes (R/(R+1))^N fraction of
// its work, where R = relative speed of unmap compared to the computation.
// For our problem, R ~ 75 because unmap unmaps 30GB/sec (but, it is serial) while
// cores go through data at the rate of 400MB/sec.
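// Worked example (illustrative): with R = 75 and 8 shards, the formula says the earliest
// unmapper should fire after (75/76)^8 ~ 0.90 of its work, i.e. when ~10% remains. The code
// approximates that schedule linearly: with munmapFraction = 0.03 and 8 shards, the thresholds
// stagger at 24%, 21%, ..., 3% of chunks remaining, so the serialized unmaps spread out
// instead of contending at the end.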
perThreadData[pos + 3] = (long) (currentChunks * (munmapFraction * (shards - i)));
perThreadData[pos + 4] = 1; // true iff munmap() hasn't been triggered yet
currentStart += currentChunks * chunkSize;
remainingChunks -= currentChunks;
}
this.chunkSize = chunkSize;
this.byteRanges = new ByteRange[shards << 4];
for (int i = 0; i < shards; i++) {
byteRanges[i << 4] = new ByteRange(raf, effectiveFileSize, i);
}
this.seqLock = new SeqLock();
}
@Override
public void close(String closerId, int shardIdx) {
byteRanges[shardIdx << 4].close(closerId);
}
@Override
public Optional fileTailEndWork(int idx) {
if (idx == 0 && effectiveFileSize < fileSize) {
ByteRange chunk = new ByteRange(raf, fileSize, 0);
chunk.setRange(
effectiveFileSize == 0 ? 0 : effectiveFileSize - 1 /* will consume newline at eFS-1 */,
fileSize);
return Optional.of(chunk);
}
return Optional.empty();
}
@Override
public ByteRange take(int shardIdx) {
// Try for thread local range
final int pos = shardIdx << 4;
final long rangeStart;
final long rangeEnd;
if (perThreadData[pos + 2] >= 1) {
rangeStart = perThreadData[pos];
rangeEnd = rangeStart + chunkSize;
// Don't do this in the if-check; it causes negative values that trigger intermediate
// cleanup
perThreadData[pos + 2]--;
if (!fakeAdvance) {
perThreadData[pos] = rangeEnd;
}
}
else {
rangeStart = commonPool.getAndAdd(commonChunkSize);
// If that's exhausted too, nothing remains!
if (rangeStart >= effectiveFileSize) {
return null;
}
rangeEnd = rangeStart + commonChunkSize;
}
if (perThreadData[pos + 2] < perThreadData[pos + 3] && perThreadData[pos + 4] > 0) {
if (attemptIntermediateClose(shardIdx)) {
perThreadData[pos + 4]--;
}
}
ByteRange chunk = byteRanges[pos];
chunk.setRange(rangeStart, rangeEnd);
return chunk;
}
private boolean attemptIntermediateClose(int shardIdx) {
if (seqLock.acquire()) {
close("Intermediate Cleaner", shardIdx);
seqLock.release();
return true;
}
return false;
}
private int consumeToPreviousNewLineExclusive(RandomAccessFile raf, int minReservedBytes) {
try {
long pos = Math.max(raf.length() - minReservedBytes - 1, -1);
if (pos < 0) {
return (int) raf.length();
}
long start = Math.max(pos - 512, 0);
ByteBuffer buf = raf.getChannel().map(MapMode.READ_ONLY, start, pos + 1 - start);
while (pos >= 0 && buf.get((int) (pos - start)) != '\n') {
pos--;
}
pos++;
return (int) (raf.length() - pos);
}
catch (Exception e) {
throw new RuntimeException(e);
}
}
}
/** A low-traffic non-blocking lock. */
static class SeqLock {
private final AtomicBoolean isOccupied = new AtomicBoolean(false);
boolean acquire() {
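// Test-and-test-and-set: the plain read avoids a CAS write (and its cache-line traffic)
// when the lock is already held; returns false instead of blocking on contention.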
return !isOccupied.get() && isOccupied.compareAndSet(false, true);
}
void release() {
isOccupied.set(false);
}
}
public static class ShardProcessor {
private final int shardIdx;
private final LazyShardQueue shardQueue;
private final FastShardProcessorState state;
public ShardProcessor(LazyShardQueue shardQueue, int hashtableSizeBits, int shardIdx) {
this.shardQueue = shardQueue;
this.shardIdx = shardIdx;
this.state = new FastShardProcessorState(hashtableSizeBits);
}
public AggregateResult processShard() {
return processShardReal();
}
private void processRange(ByteRange range) {
long nextPos = range.startAddress;
while (nextPos < range.endAddress) {
nextPos = state.processLine(nextPos);
}
}
private void processRangeSlow(ByteRange range) {
long nextPos = range.startAddress;
while (nextPos < range.endAddress) {
nextPos = state.processLineSlow(nextPos);
}
}
private AggregateResult processShardReal() {
// First process the file tail work to give ourselves freedom to go past ranges in parsing
shardQueue.fileTailEndWork(shardIdx).ifPresent(this::processRangeSlow);
ByteRange range;
while ((range = shardQueue.take(shardIdx)) != null) {
processRange(range);
}
return result();
}
private AggregateResult result() {
return state.result();
}
}
public static class FastShardProcessorState {
private static final long LEADING_ONE_BIT_MASK = 0x8080808080808080L;
private static final long ONE_MASK = 0x0101010101010101L;
private static final long SEMICOLON_MASK = 0x3b3b3b3b3b3b3b3bL;
private final Hashtable hashtable;
private final Map<String, Stat> slowProcessStats = new HashMap<>();
public FastShardProcessorState(int slotsBits) {
this.hashtable = new Hashtable(slotsBits);
Checks.checkArg(slotsBits <= 16); // since this.hashes is 'short'
}
public long processLine(long nextPos) {
final long origPos = nextPos;
// Trying to extract this into a function made it slower, so it stays inlined.
// A pity, since the extracted version was more elegant to read.
long firstLong;
int hash = 0;
// Don't run Long.numberOfTrailingZeros inside hasSemicolon; it is not needed to establish
// whether there's a semicolon, only to pin-point the length of the tail.
long s = hasSemicolon(firstLong = Unsafely.readLong(nextPos));
final int trailingZeroes;
if (s == 0) {
hash = doHash(firstLong);
do {
nextPos += 8;
s = hasSemicolon(Unsafely.readLong(nextPos));
} while (s == 0);
trailingZeroes = Long.numberOfTrailingZeros(s) + 1; // 8, 16, 24, .. # past ;
}
else {
trailingZeroes = Long.numberOfTrailingZeros(s) + 1; // 8, 16, 24, .. # past ;
hash = doHash(firstLong & maskOf(trailingZeroes - 8));
}
// Sometimes we do mix in a tail of length 0 (when the semicolon is the very first byte).
nextPos += (trailingZeroes >> 3);
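// readTemperature returns (valueInTenths << 3) | bytesConsumed packed in one int: 'temp >> 3'
// below is the temperature in tenths and 'temp & 7' advances past the digits, '.', any '-'
// sign and the trailing newline (see TemperatureLookup).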
final int temp = readTemperature(nextPos);
hashtable.addDataPoint(origPos, (int) (nextPos - 1 - origPos), hash, (short) (temp >> 3));
return nextPos + (temp & 7);
}
/**
* A slow version which is used only for the tail part of the file. Maintaining hashcode sync
* between this and the fast version is a pain for experimentation. So, we'll simply use a naive
* approach.
*/
public long processLineSlow(long nextPos) {
byte nextByte;
ByteArrayOutputStream baos = new ByteArrayOutputStream();
while ((nextByte = Unsafely.readByte(nextPos++)) != ';') {
baos.write(nextByte);
}
int temperature = 0;
boolean negative = Unsafely.readByte(nextPos) == '-';
while ((nextByte = Unsafely.readByte(nextPos++)) != '\n') {
if (nextByte != '-' && nextByte != '.') {
temperature = temperature * 10 + (nextByte - '0');
}
}
if (negative) {
temperature = -temperature;
}
updateStat(slowProcessStats, baos.toString(), Stat.firstReading(temperature));
return nextPos;
}
public AggregateResult result() {
AggregateResult result = hashtable.result();
if (!slowProcessStats.isEmpty()) {
// bah.. just mutate the arg of the record...
for (Entry<String, Stat> entry : slowProcessStats.entrySet()) {
updateStat(result.tempStats(), entry.getKey(), entry.getValue());
}
}
return result;
}
int readTemperature(long nextPos) {
// Dependency chain here:
// read -> shift -> xor -> compare -> 2 in parallel [shift -> read] -> add -> shift
// Chain latency: 2 reads + 2 adds + 4 logical ops [assuming compare = add]
// vs. the prior dependency chain (slightly optimized by hand):
// read -> compare to '-' -> read -> compare to '.' -> 3 in parallel [read -> imul] -> add
// Chain latency: 3 reads + 3 adds + 1 mul [assuming compare = add]
long data = Unsafely.readLong(nextPos);
long d = data ^ (data >> 4);
if ((data & 0xFF) == '-') {
return TemperatureLookup.firstNeg(d >> 8) + TemperatureLookup.secondNeg(d >> 24);
}
else {
return TemperatureLookup.firstPos(d >> 0) + TemperatureLookup.secondPos(d >> 16);
}
}
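// Mixes the two 32-bit halves, then folds higher bits into lower ones so that the low 16 bits
// (used as hash16 and, masked, as the table index) depend on the whole input word.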
private int doHash(long value) {
long hash = 31L * (int) value + (int) (value >> 32);
return (int) (hash ^ (hash >> 17) ^ (hash >> 28));
}
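// SWAR zero-byte detection (a la Mycroft's "find first zero byte" trick): XOR with
// SEMICOLON_MASK turns every ';' byte into 0x00, and (a - ONE_MASK) & ~a & 0x80... leaves the
// high bit set exactly in the byte lanes that were zero. Worked example: for the little-endian
// bytes "ab;cdefg", a = x ^ SEMICOLON_MASK has 0x00 in lane 2, so the result is
// 0x0000000000800000L and Long.numberOfTrailingZeros(result) = 23, i.e. (23 + 1) >> 3 = 3
// bytes up to and including the ';'.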
private long hasSemicolon(long x) {
long a = (x ^ SEMICOLON_MASK);
return (a - ONE_MASK) & (~a) & LEADING_ONE_BIT_MASK;
}
private long maskOf(int bits) {
return ~(-1L << bits);
}
private void updateStat(Map<String, Stat> map, String key, Stat curValue) {
map.compute(key, (_, value) -> value == null ? curValue : Stat.merge(value, curValue));
}
}
/** Represents aggregate stats. */
public static class Stat {
public static Stat firstReading(int temp) {
return new Stat(temp, temp, temp, 1);
}
public static Stat merge(Stat left, Stat right) {
return new Stat(
Math.min(left.min, right.min),
Math.max(left.max, right.max),
left.sum + right.sum,
left.count + right.count);
}
long count, sum;
int min, max;
public Stat(int min, int max, long sum, long count) {
this.min = min;
this.max = max;
this.sum = sum;
this.count = count;
}
// Caution: Mutates
public Stat mergeReading(int curTemp) {
// Can this be improved further?
// Assuming random values for curTemp,
// min (&max) gets updated roughly log(N)/N fraction of the time (a small number)
// In the worst case, there will be at-most one branch misprediction.
if (curTemp > min) { // Mostly passes. On branch misprediction, just update min.
if (curTemp > max) { // Mostly fails. On branch misprediction, just update max.
max = curTemp;
}
}
else {
min = curTemp;
}
sum += curTemp;
count++;
return this;
}
@Override
public String toString() {
return "%.1f/%.1f/%.1f".formatted(min / 10.0, sum / 10.0 / count, max / 10.0);
}
}
/**
* Lookup table for temperature parsing.
 *
 * <pre>
 * '0'  0011-0000
 * '9'  0011-1001
 * '.'  0010-1110
 * '\n' 0000-1010
 * </pre>
 *
* Notice that there's no overlap in the last 4 bits. This means, if we are given two successive
* bytes X, Y all of which belong to the above characters, we can REVERSIBLY hash it to
* a single byte by doing 8-bit-hash = (last 4 bits of X) concat (last 4 bits of Y).
*
* Such a hash requires a few more operations than ideal. A more practical hash is:
* (X>>4) ^ Y ^ (Y >> 4). This means if you read 4 bytes after the '-',
* L = X Y Z W, where each of X Y Z W is a byte, then,
* L ^ (L >> 4) = D hash(X, Y) hash(Y, Z) hash(Z, W) where D = don't care. In other words, we
* can SWAR the hash.
*
 * This has potential for minor conflicts; e.g. (3, NewLine) collides with (0, 9). But, we
 * don't have any collisions between two digits. That is, (x, y) will never collide with (a, b)
 * where x, y, a, b are digits (proof left as an exercise, lol). Along with a couple of other
 * such no-conflict observations, it suffices for our purposes.
*
* If we just precompute some values like
* - BigTable[hash(X,Y)] = 100*X + 10*Y
* - SmallTable[hash(Z,W)] = 10*Z + W
*
* where potentially X, Y, Z, W can be '.' or '\n', (and the arithmetic adjusted), we can lookup
* the temperature pieces from BigTable and SmallTable and add them together.
*
*
* This class is an implementation of the above idea. The lookup tables being 256 ints long
 * will always be resident in L1 cache. What remains then is to also add the information on how
 * much input is to be consumed, i.e. count the '-' and newlines too. That can be piggybacked on
 * top of the values.
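 *
 * <p>Worked example: for the bytes {@code - 1 2 . 3 \n} after the ';', FIRST_NEG[hash('2','1')]
 * holds (-120 << 3) | 2 and SECOND_NEG[hash('3','.')] holds (-3 << 3) + 4 (2 bytes, plus the
 * newline and minus sign), so their sum is (-123 << 3) + 6: shift right by 3 to get -123
 * tenths, mask with 7 to get the 6 input bytes to consume.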
 *
 * <p>FWIW, this lookup appears to have reduced the temperature reading overhead substantially
 * on a Ryzen 7950X machine. But it wasn't done systematically; so, YMMV.
*/
public static class TemperatureLookup {
// Second is the smaller (units place)
// First is the larger (100 & 10)
// _NEG tables simply negate the value so that the call site can always just add the values
// from the First and Second tables. Adding up the First and Second entries also adds up the
// amount of input to consume.
// Here, 2 is the number of bytes consumed; this informs how far the reading pointer
// should advance.
// For pattern XY value = ((-100*X -10*Y) << 3) + 2 [2 = 1 for X, 1 for Y]
// For pattern Y. value = ((-10*Y) << 3) + 2 [2 = 1 for Y, 1 for .]
private static final int[] FIRST_NEG = make(true, true);
// For pattern XY value = ((100*X + 10*Y) << 3) + 2
// For pattern Y. value = ((10*Y) << 3) + 2
private static final int[] FIRST_POS = make(true, false);
// We count newline and any initial '-' as part of SECOND
// For pattern .Z value = (-Z << 3) + 2 + 2 [1 each for . and Z, 1 for newline, 1 for minus]
// For pattern Zn value = (-Z << 3) + 1 + 2 [1 for Z, 1 for newline, 1 for minus]
private static final int[] SECOND_NEG = make(false, true);
// For pattern .Z value = (Z << 3) + 2 + 1 [1 each for . and Z, 1 for newline]
// For pattern Zn value = (Z << 3) + 1 + 1 [1 for Z, 1 for newline]
private static final int[] SECOND_POS = make(false, false);
public static int firstNeg(long b) {
return FIRST_NEG[(int) (b & 255)];
}
public static int firstPos(long b) {
return FIRST_POS[(int) (b & 255)];
}
public static int secondNeg(long b) {
return SECOND_NEG[(int) (b & 255)];
}
public static int secondPos(long b) {
return SECOND_POS[(int) (b & 255)];
}
private static byte[] allDigits() {
byte[] out = new byte[10];
for (byte a = '0'; a <= '9'; a++) {
out[a - '0'] = a;
}
return out;
}
private static int hash(byte msb, byte lsb) {
// If K = [D msb lsb], then (K ^ (K>>4)) & 255 == hash(msb, lsb). D = don't care
return (msb << 4) ^ lsb ^ (lsb >> 4);
}
private static int[] make(boolean isFirst, boolean isNegative) {
int[] ret = new int[256];
boolean[] done = new boolean[256];
// Conventions: X = 100s place, Y = 10s place, Z = 1s place, n = new line
// All the cases to handle
// X Y . Z
// Y . Z n
// In little-endian order it becomes (byte-wise), shown in place value notation
// Z . Y X
// n Z . Y
// First = YX or .Y
// Second = Z. or nZ
// Pattern 'YX'
for (byte x : allDigits()) {
for (byte y : allDigits()) {
int index = hash(y, x);
// Shouldn't occur in Second
int value = isFirst ? (y - '0') * 10 + (x - '0') * 100 : 12345;
int delta = isFirst ? 2 : 12345;
update(index, isNegative ? -value : value, delta, ret, done);
}
}
// Pattern 'Z.'
for (byte z : allDigits()) {
int index = hash(z, (byte) '.');
// shouldn't occur in First
int value = isFirst ? 12345 : (z - '0');
int delta = isFirst ? 12345 : 2;
update(index, isNegative ? -value : value, delta, ret, done);
}
// Pattern '.Y'
for (byte y : allDigits()) {
int index = hash((byte) '.', y);
// Shouldn't occur in Second
int value = isFirst ? 10 * (y - '0') : 12345;
int delta = isFirst ? 2 : 12345;
update(index, isNegative ? -value : value, delta, ret, done);
}
// Pattern 'nZ'
for (byte z : allDigits()) {
int index = hash((byte) '\n', z);
// shouldn't occur in First
int value = isFirst ? 12345 : (z - '0');
int delta = isFirst ? 12345 : 1;
update(index, isNegative ? -value : value, delta, ret, done);
}
if (!isFirst) {
// Adjust the deltas to reflect how much input needs to be consumed
// need to consume the newline and any - sign in front
for (int i = 0; i < ret.length; i++) {
ret[i] += (isNegative ? 1 : 0) /* for - sign */ + 1 /* for new line */;
}
}
return ret;
}
private static void update(int index, int value, int delta, int[] ret, boolean[] done) {
index &= 255;
Checks.checkArg(!done[index]); // just a sanity check that our hashing is indeed reversible
ret[index] = (value << 3) | delta;
done[index] = true;
}
}
static class Tracing {
private static final Map<String, ThreadTimingsArray> knownWorkThreadEvents;
private static long startTime;
static {
// Maintain the ordering to be chronological in execution
// Map.of(..) screws up ordering
knownWorkThreadEvents = new LinkedHashMap<>();
for (String id : List.of("Shard", "Intermediate Cleaner", "Ending Cleaner", "Buffer Creation")) {
knownWorkThreadEvents.put(id, new ThreadTimingsArray(id, 1 << 10));
}
}
static void analyzeWorkThreads(int nThreads) {
for (ThreadTimingsArray array : knownWorkThreadEvents.values()) {
errPrint(array.analyze(nThreads));
}
}
static void recordAppStart() {
startTime = System.nanoTime();
printEvent("Start time", startTime);
}
static void recordEvent(String event) {
printEvent(event, System.nanoTime());
}
static void recordWorkEnd(String id, int threadId) {
knownWorkThreadEvents.get(id).recordEnd(threadId);
}
static void recordWorkStart(String id, int threadId) {
knownWorkThreadEvents.get(id).recordStart(threadId);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
private static void errPrint(String message) {
System.err.println(message);
}
private static void printEvent(String message, long nanoTime) {
errPrint(STR."\{message} = \{(nanoTime - startTime) / 1_000_000}ms");
}
public static class ThreadTimingsArray {
private static String toString(long[] array) {
return Arrays.stream(array)
.map(x -> x < 0 ? -1 : x)
.mapToObj(x -> String.format("%6d", x))
.collect(Collectors.joining(", ", "[ ", " ]"));
}
private final String id;
private final long[] timestamps;
private boolean hasData = false;
public ThreadTimingsArray(String id, int maxSize) {
this.timestamps = new long[maxSize];
this.id = id;
}
public String analyze(int nThreads) {
if (!hasData) {
return "%s has no thread timings data".formatted(id);
}
Checks.checkArg(nThreads <= timestamps.length);
long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE;
long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE;
long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE;
long[] durationsMs = new long[nThreads];
long[] completionsMs = new long[nThreads];
long[] beginMs = new long[nThreads];
for (int i = 0; i < nThreads; i++) {
long durationNs = timestamps[2 * i + 1] - timestamps[2 * i];
durationsMs[i] = durationNs / 1_000_000;
completionsMs[i] = (timestamps[2 * i + 1] - startTime) / 1_000_000;
beginMs[i] = (timestamps[2 * i] - startTime) / 1_000_000;
minDuration = Math.min(minDuration, durationNs);
maxDuration = Math.max(maxDuration, durationNs);
minBegin = Math.min(minBegin, timestamps[2 * i] - startTime);
maxBegin = Math.max(maxBegin, timestamps[2 * i] - startTime);
maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1] - startTime);
minCompletion = Math.min(minCompletion, timestamps[2 * i + 1] - startTime);
}
return STR."""
-------------------------------------------------------------------------------------------
\{id} Stats
-------------------------------------------------------------------------------------------
Max duration = \{maxDuration / 1_000_000} ms
Min duration = \{minDuration / 1_000_000} ms
Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms [\{maxCompletion / 1_000_000} - \{minBegin / 1_000_000} ]
Completion Timespan[max(end)-min(end)] = \{(maxCompletion - minCompletion) / 1_000_000} ms
Begin Timespan[max(begin)-min(begin)] = \{(maxBegin - minBegin) / 1_000_000} ms
Average Duration = \{Arrays.stream(durationsMs)
.average()
.getAsDouble()} ms
Durations = \{toString(durationsMs)} ms
Begin Timestamps = \{toString(beginMs)} ms
Completion Timestamps = \{toString(completionsMs)} ms
""";
}
public void recordEnd(int idx) {
timestamps[2 * idx + 1] = System.nanoTime();
hasData = true;
}
public void recordStart(int idx) {
timestamps[2 * idx] = System.nanoTime();
hasData = true;
}
}
}
static class Unsafely {
private static final Unsafe unsafe = getUnsafe();
public static long allocateZeroedCacheLineAligned(int size) {
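// Over-allocate by 63 bytes, zero the block, then round the base up to the next 64-byte
// (cache line) boundary; e.g. a raw address of 1000 becomes (1000 + 63) & ~63 = 1024.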
long address = unsafe.allocateMemory(size + 63);
unsafe.setMemory(address, size + 63, (byte) 0);
return (address + 63) & ~63;
}
public static void copyMemory(long srcAddress, long destAddress, long byteCount) {
unsafe.copyMemory(srcAddress, destAddress, byteCount);
}
public static boolean matches(long srcAddr, long destAddress, int len) {
if (len < 8) {
return (readLong(srcAddr) & ~(-1L << (len << 3))) == (readLong(destAddress) & ~(-1L << (len << 3)));
}
if (readLong(srcAddr) != readLong(destAddress)) {
return false;
}
len -= 8;
if (len < 8) {
return (readLong(srcAddr + 8) & ~(-1L << (len << 3))) == (readLong(destAddress + 8) & ~(-1L << (len << 3)));
}
if (readLong(srcAddr + 8) != readLong(destAddress + 8)) {
return false;
}
len -= 8;
srcAddr += 16;
destAddress += 16;
int idx = 0;
for (; idx < (len & ~7); idx += 8) {
if (Unsafely.readLong(srcAddr + idx) != Unsafely.readLong(destAddress + idx)) {
return false;
}
}
if (idx < (len & ~3)) {
if (Unsafely.readInt(srcAddr + idx) != Unsafely.readInt(destAddress + idx)) {
return false;
}
idx += 4;
}
if (idx < (len & ~1)) {
if (Unsafely.readShort(srcAddr + idx) != Unsafely.readShort(destAddress + idx)) {
return false;
}
idx += 2;
}
return idx >= len || Unsafely.readByte(srcAddr + idx) == Unsafely.readByte(destAddress + idx);
}
public static byte readByte(long address) {
return unsafe.getByte(address);
}
public static int readInt(long address) {
return unsafe.getInt(address);
}
public static long readLong(long address) {
return unsafe.getLong(address);
}
public static short readShort(long address) {
return unsafe.getShort(address);
}
public static void setByte(long address, byte len) {
unsafe.putByte(address, len);
}
public static void setInt(long address, int value) {
unsafe.putInt(address, value);
}
public static void setShort(long address, short len) {
unsafe.putShort(address, len);
}
private static Unsafe getUnsafe() {
try {
Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
unsafeField.setAccessible(true);
return (Unsafe) unsafeField.get(null);
}
catch (NoSuchFieldException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}
}
}