branchy version (#408)

This commit is contained in:
Artsiom Korzun 2024-01-15 19:57:34 +01:00 committed by GitHub
parent ca075b66f2
commit 987da54906
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,7 +20,6 @@ import sun.misc.Unsafe;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.reflect.Field;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
@ -38,11 +37,10 @@ public class CalculateAverage_artsiomkorzun {
private static final int SEGMENT_SIZE = 32 * 1024 * 1024;
private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE);
private static final int SEGMENT_OVERLAP = 1024;
private static final long COMMA_PATTERN = pattern(';');
private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL;
private static final long DOT_BITS = 0x10101000;
private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder();
private static final Unsafe UNSAFE;
static {
@ -95,19 +93,15 @@ public class CalculateAverage_artsiomkorzun {
}
}
private static long pattern(char c) {
long b = c & 0xFFL;
return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56);
}
private static long getLongLittleEndian(long address) {
long value = UNSAFE.getLong(address);
if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
value = Long.reverseBytes(value);
}
return value;
private static long word(long address) {
return UNSAFE.getLong(address);
/*
* if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
* value = Long.reverseBytes(value);
* }
*
* return value;
*/
}
private static String text(Map<String, Aggregate> aggregates) {
@ -140,7 +134,7 @@ public class CalculateAverage_artsiomkorzun {
private static class Aggregates {
private static final int ENTRIES = 64 * 1024;
private static final int SIZE = 32 * ENTRIES;
private static final int SIZE = 128 * ENTRIES;
private final long pointer;
@ -150,62 +144,82 @@ public class CalculateAverage_artsiomkorzun {
UNSAFE.setMemory(pointer, SIZE, (byte) 0);
}
public void add(long reference, int length, int hash, int value) {
public long find(long word, int hash) {
long address = pointer + offset(hash);
long w = word(address + 24);
return (w == word) ? address : 0;
}
public long find(long word1, long word2, int hash) {
long address = pointer + offset(hash);
long w1 = word(address + 24);
long w2 = word(address + 32);
return (word1 == w1) && (word2 == w2) ? address : 0;
}
public long put(long reference, long word, int length, int hash) {
for (int offset = offset(hash);; offset = next(offset)) {
long address = pointer + offset;
long ref = UNSAFE.getLong(address);
if (ref == 0) {
alloc(reference, length, hash, value, address);
break;
if (equal(reference, word, address + 24, length)) {
return address;
}
if (equal(ref, reference, length)) {
long sum = UNSAFE.getLong(address + 16) + value;
int cnt = UNSAFE.getInt(address + 24) + 1;
short min = (short) Math.min(UNSAFE.getShort(address + 28), value);
short max = (short) Math.max(UNSAFE.getShort(address + 30), value);
UNSAFE.putLong(address + 16, sum);
UNSAFE.putInt(address + 24, cnt);
UNSAFE.putShort(address + 28, min);
UNSAFE.putShort(address + 30, max);
break;
int len = UNSAFE.getInt(address);
if (len == 0) {
alloc(reference, length, hash, address);
return address;
}
}
}
public void merge(Aggregates rights) {
for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 32) {
long rightAddress = rights.pointer + rightOffset;
long reference = UNSAFE.getLong(rightAddress);
public static void update(long address, int value) {
long sum = UNSAFE.getLong(address + 8) + value;
int cnt = UNSAFE.getInt(address + 16) + 1;
short min = UNSAFE.getShort(address + 20);
short max = UNSAFE.getShort(address + 22);
if (reference == 0) {
UNSAFE.putLong(address + 8, sum);
UNSAFE.putInt(address + 16, cnt);
if (value < min) {
UNSAFE.putShort(address + 20, (short) value);
}
if (value > max) {
UNSAFE.putShort(address + 22, (short) value);
}
}
public void merge(Aggregates rights) {
for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 128) {
long rightAddress = rights.pointer + rightOffset;
int length = UNSAFE.getInt(rightAddress);
if (length == 0) {
continue;
}
int hash = UNSAFE.getInt(rightAddress + 8);
int length = UNSAFE.getInt(rightAddress + 12);
int hash = UNSAFE.getInt(rightAddress + 4);
for (int offset = offset(hash);; offset = next(offset)) {
long address = pointer + offset;
long ref = UNSAFE.getLong(address);
int len = UNSAFE.getInt(address);
if (ref == 0) {
UNSAFE.copyMemory(rightAddress, address, 32);
if (len == 0) {
UNSAFE.copyMemory(rightAddress, address, 24 + length);
break;
}
if (equal(ref, reference, length)) {
long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16);
int cnt = UNSAFE.getInt(address + 24) + UNSAFE.getInt(rightAddress + 24);
short min = (short) Math.min(UNSAFE.getShort(address + 28), UNSAFE.getShort(rightAddress + 28));
short max = (short) Math.max(UNSAFE.getShort(address + 30), UNSAFE.getShort(rightAddress + 30));
if (len == length && equal(address + 24, rightAddress + 24, length)) {
long sum = UNSAFE.getLong(address + 8) + UNSAFE.getLong(rightAddress + 8);
int cnt = UNSAFE.getInt(address + 16) + UNSAFE.getInt(rightAddress + 16);
short min = (short) Math.min(UNSAFE.getShort(address + 20), UNSAFE.getShort(rightAddress + 20));
short max = (short) Math.max(UNSAFE.getShort(address + 22), UNSAFE.getShort(rightAddress + 22));
UNSAFE.putLong(address + 16, sum);
UNSAFE.putInt(address + 24, cnt);
UNSAFE.putShort(address + 28, min);
UNSAFE.putShort(address + 30, max);
UNSAFE.putLong(address + 8, sum);
UNSAFE.putInt(address + 16, cnt);
UNSAFE.putShort(address + 20, min);
UNSAFE.putShort(address + 22, max);
break;
}
}
@ -215,20 +229,19 @@ public class CalculateAverage_artsiomkorzun {
public Map<String, Aggregate> aggregate() {
TreeMap<String, Aggregate> set = new TreeMap<>();
for (int offset = 0; offset < SIZE; offset += 32) {
for (int offset = 0; offset < SIZE; offset += 128) {
long address = pointer + offset;
long ref = UNSAFE.getLong(address);
int length = UNSAFE.getInt(address);
if (ref != 0) {
int length = UNSAFE.getInt(address + 12) - 1;
if (length != 0) {
byte[] array = new byte[length];
UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
UNSAFE.copyMemory(null, address + 24, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
String key = new String(array);
long sum = UNSAFE.getLong(address + 16);
int cnt = UNSAFE.getInt(address + 24);
short min = UNSAFE.getShort(address + 28);
short max = UNSAFE.getShort(address + 30);
long sum = UNSAFE.getLong(address + 8);
int cnt = UNSAFE.getInt(address + 16);
short min = UNSAFE.getShort(address + 20);
short max = UNSAFE.getShort(address + 22);
Aggregate aggregate = new Aggregate(min, max, sum, cnt);
set.put(key, aggregate);
@ -238,26 +251,24 @@ public class CalculateAverage_artsiomkorzun {
return set;
}
private static void alloc(long reference, int length, int hash, int value, long address) {
UNSAFE.putLong(address, reference);
UNSAFE.putInt(address + 8, hash);
UNSAFE.putInt(address + 12, length);
UNSAFE.putLong(address + 16, value);
UNSAFE.putInt(address + 24, 1);
UNSAFE.putShort(address + 28, (short) value);
UNSAFE.putShort(address + 30, (short) value);
private static void alloc(long reference, int length, int hash, long address) {
UNSAFE.putInt(address, length);
UNSAFE.putInt(address + 4, hash);
UNSAFE.putShort(address + 20, Short.MAX_VALUE);
UNSAFE.putShort(address + 22, Short.MIN_VALUE);
UNSAFE.copyMemory(reference, address + 24, length);
}
private static int offset(int hash) {
return ((hash) & (ENTRIES - 1)) << 5;
return ((hash) & (ENTRIES - 1)) << 7;
}
private static int next(int prev) {
return (prev + 32) & (SIZE - 1);
return (prev + 128) & (SIZE - 1);
}
private static boolean equal(long leftAddress, long rightAddress, int length) {
while (length > 8) {
private static boolean equal(long leftAddress, long leftWord, long rightAddress, int length) {
while (length >= 8) {
long left = UNSAFE.getLong(leftAddress);
long right = UNSAFE.getLong(rightAddress);
@ -270,10 +281,24 @@ public class CalculateAverage_artsiomkorzun {
length -= 8;
}
int shift = (8 - length) << 3;
long left = getLongLittleEndian(leftAddress) << shift;
long right = getLongLittleEndian(rightAddress) << shift;
return (left == right);
return leftWord == word(rightAddress);
}
private static boolean equal(long leftAddress, long rightAddress, int length) {
do {
long left = UNSAFE.getLong(leftAddress);
long right = UNSAFE.getLong(rightAddress);
if (left != right) {
return false;
}
leftAddress += 8;
rightAddress += 8;
length -= 8;
} while (length > 0);
return true;
}
}
@ -320,45 +345,89 @@ public class CalculateAverage_artsiomkorzun {
// as a result a read will be split across pages, where one of them is not mapped
// but for some reason it works on my machine, leaving to investigate
for (long start = position, hash = 0; position <= limit;) {
int length; // idea: royvanrijn, explanation: https://richardstartin.github.io/posts/finding-bytes
{
long word = getLongLittleEndian(position);
long match = word ^ COMMA_PATTERN;
long mask = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
while (position <= limit) { // branchy version, credit: thomaswue
int length;
int hash;
if (mask == 0) {
hash ^= word;
position += 8;
continue;
long ptr = 0;
long word = word(position);
long separator = separator(word);
if (separator != 0) {
length = length(separator);
word = mask(word, separator);
hash = mix(word);
ptr = aggregates.find(word, hash);
}
else {
long word0 = word;
word = word(position + 8);
separator = separator(word);
if (separator != 0) {
length = length(separator) + 8;
word = mask(word, separator);
hash = mix(word ^ word0);
ptr = aggregates.find(word0, word, hash);
}
else {
length = 16;
long h = word ^ word0;
int bit = Long.numberOfTrailingZeros(mask);
position += (bit >>> 3) + 1; // +sep
hash ^= (word << (69 - bit));
length = (int) (position - start);
while (true) {
word = word(position + length);
separator = separator(word);
if (separator == 0) {
length += 8;
h ^= word;
continue;
}
length += length(separator);
word = mask(word, separator);
hash = mix(h ^ word);
break;
}
}
}
int value; // idea: merykitty
{
long word = getLongLittleEndian(position);
long inverted = ~word;
int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS);
long signed = (inverted << 59) >> 63;
long mask = ~(signed & 0xFF);
long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L;
long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
value = (int) ((abs ^ signed) - signed);
position += (dot >> 3) + 3;
if (ptr == 0) {
ptr = aggregates.put(position, word, length, hash);
}
aggregates.add(start, length, mix(hash), value);
start = position;
hash = 0;
position = update(ptr, position + length + 1);
}
}
private static long update(long ptr, long position) {
// idea: merykitty
long word = word(position);
long inverted = ~word;
int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS);
long signed = (inverted << 59) >> 63;
long mask = ~(signed & 0xFF);
long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L;
long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
int value = (int) ((abs ^ signed) - signed);
Aggregates.update(ptr, value);
return position + (dot >> 3) + 3;
}
private static long separator(long word) {
long match = word ^ COMMA_PATTERN;
return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
}
private static long mask(long word, long separator) {
return word & ((separator >>> 7) - 1) & 0x00FFFFFFFFFFFFFFL;
}
private static int length(long separator) {
return Long.numberOfTrailingZeros(separator) >>> 3;
}
private static long next(long position) {
while (UNSAFE.getByte(position++) != '\n') {
// continue