armandino: misc improvements (#673)

This commit is contained in:
Arman Sharif 2024-01-31 00:39:08 -08:00 committed by GitHub
parent a5ce4ba771
commit 974ddbae60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -24,15 +24,12 @@ import java.lang.reflect.Field;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Stream;
import static java.nio.channels.FileChannel.MapMode.READ_ONLY;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.toMap;
public class CalculateAverage_armandino {
@ -42,19 +39,59 @@ public class CalculateAverage_armandino {
private static final int INITIAL_MAP_CAPACITY = 8192;
private static final byte SEMICOLON = 59;
private static final byte NL = 10;
private static final byte DOT = 46;
private static final byte MINUS = 45;
private static final byte ZERO_DIGIT = 48;
private static final int PRIME = 1117;
private static final int KEY_OFFSET = 0, // 100b
HASH_OFFSET = 100, // int
KEY_LENGTH_OFFSET = 104, // short
MIN_OFFSET = 106, // short
MAX_OFFSET = 108, // short
COUNT_OFFSET = 110, // int
SUM_OFFSET = 114; // long
// Size in bytes of one off-heap map entry. Each term below is one field; the
// offsets agree with the *_OFFSET constants and the terms sum to 122.
private static final long ENTRY_SIZE = 100 // key: offset=0, ends at 100
+ 4 // keyHash: offset=100, ends at 104
+ 2 // keyLength: offset=104, ends at 106
+ 2 // min: offset=106, ends at 108
+ 2 // max: offset=108, ends at 110
+ 4 // count: offset=110, ends at 114
+ 8; // sum: offset=114, ends at 122
private static final Unsafe UNSAFE = getUnsafe();
public static void main(String[] args) throws Exception {
var channel = FileChannel.open(FILE, StandardOpenOption.READ);
var results = Arrays.stream(split(channel)).parallel()
.map(chunk -> new ChunkProcessor().process(chunk.start, chunk.end))
.flatMap(SimpleMap::stream)
.collect(toMap(Stats::getKey, s -> s, CalculateAverage_armandino::mergeStats, TreeMap::new));
Chunk[] chunks = split(channel);
ChunkProcessor[] processors = new ChunkProcessor[chunks.length];
for (int i = 0; i < processors.length; i++) {
processors[i] = new ChunkProcessor(chunks[i].start, chunks[i].end);
processors[i].start();
}
Map<String, Stats> results = new TreeMap<>();
for (int i = 0; i < processors.length; i++) {
processors[i].join();
final long end = processors[i].map.mapEnd;
for (long addr = processors[i].map.mapStart; addr < end; addr += ENTRY_SIZE) {
final short keyLength = UNSAFE.getShort(addr + KEY_LENGTH_OFFSET);
if (keyLength == 0)
continue;
final byte[] keyBytes = new byte[keyLength];
UNSAFE.copyMemory(null, addr, keyBytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, keyLength);
final short min = UNSAFE.getShort(addr + MIN_OFFSET);
final short max = UNSAFE.getShort(addr + MAX_OFFSET);
final int count = UNSAFE.getInt(addr + COUNT_OFFSET);
final long sum = UNSAFE.getLong(addr + SUM_OFFSET);
final Stats s = new Stats(new String(keyBytes, 0, keyLength, UTF_8), min, max, count, sum);
results.merge(s.key, s, CalculateAverage_armandino::mergeStats);
}
}
print(results.values());
}
@ -67,87 +104,69 @@ public class CalculateAverage_armandino {
return x;
}
private static class ChunkProcessor {
private final SimpleMap map = new SimpleMap(INITIAL_MAP_CAPACITY);
private static class ChunkProcessor extends Thread {
private final UnsafeMap map = new UnsafeMap(INITIAL_MAP_CAPACITY);
private SimpleMap process(final long chunkStart, final long chunkEnd) {
final long chunkStart;
final long chunkEnd;
private ChunkProcessor(long chunkStart, long chunkEnd) {
this.chunkStart = chunkStart;
this.chunkEnd = chunkEnd;
}
@Override
public void run() {
long i = chunkStart;
while (i < chunkEnd) {
final long keyAddress = i;
int keyHash = 0;
int measurement = 0;
byte b;
while ((b = UNSAFE.getByte(i++)) != SEMICOLON) {
keyHash = PRIME * keyHash + b;
}
final int keyLength = (int) (i - keyAddress - 1);
final short keyLength = (short) (i - keyAddress - 1);
final long numberWord = UNSAFE.getLong(i);
final int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000);
final short measurement = parseNumber(decimalSepPos, numberWord);
final int addOffset = (decimalSepPos >>> 3) + 3;
i += addOffset;
if ((b = UNSAFE.getByte(i++)) == MINUS) {
while ((b = UNSAFE.getByte(i++)) != DOT) {
measurement = measurement * 10 + b - ZERO_DIGIT;
}
b = UNSAFE.getByte(i);
measurement = measurement * 10 + b - ZERO_DIGIT;
measurement = -measurement;
i += 2;
}
else {
measurement = b - ZERO_DIGIT; // D1
b = UNSAFE.getByte(i); // dot or D2
if (b == DOT) {
measurement = measurement * 10 + UNSAFE.getByte(i + 1) - ZERO_DIGIT; // F
i += 3;
}
else {
measurement = measurement * 10 + b - ZERO_DIGIT; // D2
measurement = measurement * 10 + UNSAFE.getByte(i + 2) - ZERO_DIGIT; // F
i += 4; // skip NL
}
}
final Stats stats = map.putStats(keyHash, keyAddress, keyLength);
stats.min = Math.min(stats.min, measurement);
stats.max = Math.max(stats.max, measurement);
stats.sum += measurement;
stats.count++;
map.addEntry(keyHash, keyAddress, keyLength, measurement);
}
}
return map;
// credit: merykitty
/**
 * Converts the ASCII number held in {@code numberWord} into an integer number of tenths
 * (for example the text "-12.3" yields -123) without branching.
 *
 * NOTE(review): assumes the first character of the number sits in the least-significant
 * byte of {@code numberWord} (little-endian read) and that the text has the shape
 * [-]d.d or [-]dd.d -- confirm against the caller.
 *
 * @param decimalSepPos bit index computed by the caller as
 *                      numberOfTrailingZeros(~numberWord & 0x10101000); presumably bit 4 of the
 *                      byte holding the '.' separator
 * @param numberWord    8 bytes read starting at the first character of the number
 * @return the signed value multiplied by ten, narrowed to a short
 */
private static short parseNumber(int decimalSepPos, long numberWord) {
// Shift needed to move the '.' byte to a fixed position regardless of how many digits precede it
int shift = 28 - decimalSepPos;
// signed is -1 if negative, 0 otherwise
// (bit 4 of the first byte is 0 for '-' (0x2D) and 1 for the digits 0x30..0x39)
long signed = (~numberWord << 59) >> 63;
// When negative, this mask clears the lowest byte (the '-' character); otherwise it keeps everything
long designMask = ~(signed & 0xFF);
// Align the number to a specific position and transform the ascii to digit value
long digits = ((numberWord & designMask) << shift) & 0x0F000F0F00L;
// Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit)
// 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
// 0x000000UU00TTHH00 + 0x00UU00TTHH000000 * 10 + 0xUU00TTHH00000000 * 100
// The weighted sum lands in bits 32..41, hence the shift by 32 and the 10-bit mask
long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
// Conditional negate: identity when signed == 0, two's-complement negation when signed == -1
return (short) ((absValue ^ signed) - signed);
}
}
private static class Stats implements Comparable<Stats> {
private String key;
private final long keyAddress;
private final int keyLength;
private final int keyHash;
private int min = Integer.MAX_VALUE;
private int max = Integer.MIN_VALUE;
private static class Stats {
private final String key;
private int min;
private int max;
private int count;
private long sum;
private Stats(long keyAddress, int keyLength, int keyHash) {
this.keyAddress = keyAddress;
this.keyLength = keyLength;
this.keyHash = keyHash;
}
String getKey() {
if (key == null) {
var keyBytes = new byte[keyLength];
UNSAFE.copyMemory(null, keyAddress, keyBytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, keyLength);
key = new String(keyBytes, 0, keyLength, UTF_8);
}
return key;
}
@Override
public int compareTo(final Stats o) {
return getKey().compareTo(o.getKey());
Stats(final String key, final int min, final int max, final int count, final long sum) {
this.min = min;
this.max = max;
this.count = count;
this.sum = sum;
this.key = key;
}
void print(final PrintStream out) {
@ -219,90 +238,114 @@ public class CalculateAverage_armandino {
}
}
private static class SimpleMap {
private Stats[] table;
private static class UnsafeMap {
SimpleMap(int initialCapacity) {
table = new Stats[initialCapacity];
long mapStart;
long mapEnd;
int capacity; // num entries
/**
 * Allocates a zero-filled off-heap table with room for {@code numEntries} entries
 * of ENTRY_SIZE bytes each.
 *
 * @param numEntries number of entry slots in the table
 */
UnsafeMap(int numEntries) {
    final long bytes = ENTRY_SIZE * numEntries;
    final long base = UNSAFE.allocateMemory(bytes);
    // Zero the whole region up front; the rest of the class treats zeroed fields as "slot unused".
    UNSAFE.setMemory(base, bytes, (byte) 0);

    capacity = numEntries;
    mapStart = base;
    mapEnd = base + bytes;
}
Stream<Stats> stream() {
return Arrays.stream(table).filter(Objects::nonNull);
}
void addEntry(final int keyHash, final long keyAddress, final short keyLength, final short measurement) {
final int pos = (capacity - 1) & keyHash;
Stats putStats(final int keyHash, final long keyAddress, final int keyLength) {
final int pos = (table.length - 1) & keyHash;
long addr = mapStart + pos * ENTRY_SIZE;
int hash = UNSAFE.getInt(addr + HASH_OFFSET);
Stats stats = table[pos];
if (stats == null)
return createAt(table, keyAddress, keyLength, keyHash, pos);
if (stats.keyHash == keyHash && keysEqual(stats, keyAddress, keyLength))
return stats;
int i = pos;
while (++i < table.length) {
stats = table[i];
if (stats == null)
return createAt(table, keyAddress, keyLength, keyHash, i);
if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength))
return stats;
if (hash == 0) { // new entry
initEntry(addr, keyAddress, keyLength, measurement, keyHash);
return;
}
if (hash == keyHash && keysEqual(addr, keyAddress, keyLength)) {
updateEntry(addr, measurement);
return;
}
i = pos;
while (i-- > 0) {
stats = table[i];
if (stats == null)
return createAt(table, keyAddress, keyLength, keyHash, i);
if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength))
return stats;
}
resize();
return putStats(keyHash, keyAddress, keyLength);
}
// this can be improved to avoid clustering at the start.
// should only affect the 10k test
addr = mapStart;
private static Stats createAt(Stats[] table, long keyAddress, int keyLength, int key, int i) {
Stats stats = new Stats(keyAddress, keyLength, key);
table[i] = stats;
return stats;
}
while (addr < mapEnd) {
addr += ENTRY_SIZE;
hash = UNSAFE.getInt(addr + HASH_OFFSET);
private static boolean keysEqual(Stats stats, long keyAddress, final int keyLength) {
// credit: abeobk
long xsum = 0;
int n = keyLength & 0xF8;
for (int i = 0; i < n; i += 8) {
xsum |= (UNSAFE.getLong(stats.keyAddress + i) ^ UNSAFE.getLong(keyAddress + i));
}
return xsum == 0;
}
private void resize() {
var copy = new SimpleMap(table.length * 2);
for (Stats s : table) {
if (s != null) {
final int pos = (copy.table.length - 1) & s.keyHash;
int i = pos;
if (copy.table[i] == null) {
copy.table[i] = s;
continue;
}
while (i < copy.table.length && copy.table[i] != null) {
i++;
}
if (i == copy.table.length) {
i = pos;
while (i >= 0 && copy.table[i] != null) {
i--;
}
}
if (i < 0) {
// if we reach here it's a bug!
throw new IllegalStateException("table is full");
}
copy.table[i] = s;
if (hash == 0) {
initEntry(addr, keyAddress, keyLength, measurement, keyHash);
return;
}
if (hash == keyHash && keysEqual(addr, keyAddress, keyLength)) {
updateEntry(addr, measurement);
return;
}
}
table = copy.table;
resize(keyHash, keyAddress, keyLength, measurement);
}
/**
 * Grows the table to twice its capacity, moves every occupied entry into the new
 * table, inserts the pending measurement, and releases the old allocation.
 *
 * Differences from the previous implementation:
 * - the whole entry, including the key bytes, is copied (the key was dropped before);
 * - unused slots are skipped instead of being copied over slot 0;
 * - an entry whose home slot is already taken is placed in a free slot instead of
 *   overwriting the occupant;
 * - the old off-heap block is freed.
 *
 * @param keyHash     hash of the key that triggered the resize
 * @param keyAddress  address of that key's bytes
 * @param keyLength   length of that key in bytes
 * @param measurement value to record for that key once the table has grown
 */
private void resize(final int keyHash, final long keyAddress, final short keyLength, final short measurement) {
    final UnsafeMap newMap = new UnsafeMap(capacity * 2);
    final long oldStart = mapStart;

    for (long addr = mapStart; addr < mapEnd; addr += ENTRY_SIZE) {
        // keyLength == 0 marks a slot that was never initialised.
        if (UNSAFE.getShort(addr + KEY_LENGTH_OFFSET) == 0) {
            continue;
        }

        final int oKeyHash = UNSAFE.getInt(addr + HASH_OFFSET);
        final int newPos = (newMap.capacity - 1) & oKeyHash;
        long newAddr = newMap.mapStart + newPos * ENTRY_SIZE;

        if (UNSAFE.getShort(newAddr + KEY_LENGTH_OFFSET) != 0) {
            // Home slot is taken: use the first free slot of the overflow scan.
            // The scan starts at the second slot because addEntry's fallback probe
            // advances before its first read, so lookups will find the entry here.
            newAddr = newMap.mapStart + ENTRY_SIZE;
            while (newAddr < newMap.mapEnd && UNSAFE.getShort(newAddr + KEY_LENGTH_OFFSET) != 0) {
                newAddr += ENTRY_SIZE;
            }
            if (newAddr >= newMap.mapEnd) {
                // Cannot happen while the new table is twice the size of the old one.
                throw new IllegalStateException("table is full");
            }
        }

        // Copy key, hash, key length, min, max, count and sum in one go.
        UNSAFE.copyMemory(addr, newAddr, ENTRY_SIZE);
    }

    newMap.addEntry(keyHash, keyAddress, keyLength, measurement);

    this.mapStart = newMap.mapStart;
    this.mapEnd = newMap.mapEnd;
    this.capacity = newMap.capacity;

    // Nothing references the old table any more.
    UNSAFE.freeMemory(oldStart);
}
/**
 * Initialises a previously unused entry with the given key and records the first measurement.
 *
 * @param entry       address of the entry to initialise
 * @param keyAddress  address of the key bytes to copy into the entry
 * @param keyLength   number of key bytes to copy
 * @param measurement first measurement for this key
 * @param keyHash     hash of the key, stored alongside it
 */
private static void initEntry(final long entry, final long keyAddress, final short keyLength, final short measurement, final int keyHash) {
    // Key metadata first, then the key bytes themselves (stored at the start of the entry).
    UNSAFE.putShort(entry + KEY_LENGTH_OFFSET, keyLength);
    UNSAFE.putInt(entry + HASH_OFFSET, keyHash);
    UNSAFE.copyMemory(keyAddress, entry, keyLength);

    // Seed min/max with the extreme values so the first update always replaces them.
    UNSAFE.putShort(entry + MAX_OFFSET, Short.MIN_VALUE);
    UNSAFE.putShort(entry + MIN_OFFSET, Short.MAX_VALUE);

    updateEntry(entry, measurement);
}
/**
 * Folds one measurement into an existing entry: lowers min, raises max,
 * increments count and adds to sum.
 *
 * @param entry       address of the entry to update
 * @param measurement value to fold in
 */
private static void updateEntry(final long entry, final short measurement) {
    final long minAddr = entry + MIN_OFFSET;
    final long maxAddr = entry + MAX_OFFSET;
    final long countAddr = entry + COUNT_OFFSET;
    final long sumAddr = entry + SUM_OFFSET;

    // Only write when the stored bound actually changes; the resulting memory is the same.
    if (measurement < UNSAFE.getShort(minAddr)) {
        UNSAFE.putShort(minAddr, measurement);
    }
    if (measurement > UNSAFE.getShort(maxAddr)) {
        UNSAFE.putShort(maxAddr, measurement);
    }

    UNSAFE.putInt(countAddr, UNSAFE.getInt(countAddr) + 1);
    UNSAFE.putLong(sumAddr, UNSAFE.getLong(sumAddr) + measurement);
}
}
/**
 * Compares {@code keyLength} bytes at two native addresses.
 *
 * The previous version compared only whole 8-byte words, so the trailing
 * {@code keyLength % 8} bytes were never checked and keys shorter than 8 bytes
 * always compared equal. The tail is now compared byte by byte, which also avoids
 * reading past the end of either key.
 *
 * @param key1Address address of the first key
 * @param key2Address address of the second key
 * @param keyLength   number of bytes to compare
 * @return true when all {@code keyLength} bytes are identical
 */
private static boolean keysEqual(long key1Address, long key2Address, final int keyLength) {
    // credit: abeobk
    final int wordBytes = keyLength & ~7; // largest multiple of 8 not exceeding keyLength
    long diff = 0;
    int i = 0;

    // Whole 8-byte words: accumulate any differing bits.
    for (; i < wordBytes; i += 8) {
        diff |= UNSAFE.getLong(key1Address + i) ^ UNSAFE.getLong(key2Address + i);
    }
    // Remaining 0..7 bytes.
    for (; i < keyLength; i++) {
        diff |= UNSAFE.getByte(key1Address + i) ^ UNSAFE.getByte(key2Address + i);
    }
    return diff == 0;
}
}