branchy version (#408)

This commit is contained in:
Artsiom Korzun 2024-01-15 19:57:34 +01:00 committed by GitHub
parent ca075b66f2
commit 987da54906
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,7 +20,6 @@ import sun.misc.Unsafe;
import java.lang.foreign.Arena; import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment; import java.lang.foreign.MemorySegment;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel; import java.nio.channels.FileChannel;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.StandardOpenOption; import java.nio.file.StandardOpenOption;
@ -38,11 +37,10 @@ public class CalculateAverage_artsiomkorzun {
private static final int SEGMENT_SIZE = 32 * 1024 * 1024; private static final int SEGMENT_SIZE = 32 * 1024 * 1024;
private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE); private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE);
private static final int SEGMENT_OVERLAP = 1024; private static final int SEGMENT_OVERLAP = 1024;
private static final long COMMA_PATTERN = pattern(';'); private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL;
private static final long DOT_BITS = 0x10101000; private static final long DOT_BITS = 0x10101000;
private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder();
private static final Unsafe UNSAFE; private static final Unsafe UNSAFE;
static { static {
@ -95,19 +93,15 @@ public class CalculateAverage_artsiomkorzun {
} }
} }
// Side-by-side diff rows (removed column on the left, added column on the right).
// Removed: pattern(char), which repeats the low byte of c in all eight bytes of a long, and
// getLongLittleEndian(long), which reads a long via UNSAFE and byte-swaps it when BYTE_ORDER
// is BIG_ENDIAN.
// Added: word(long), which returns UNSAFE.getLong(address) unchanged; the big-endian byte swap
// survives only inside the block comment, so no byte-order correction is applied.
private static long pattern(char c) { private static long word(long address) {
long b = c & 0xFFL; return UNSAFE.getLong(address);
return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56); /*
} * if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
* value = Long.reverseBytes(value);
private static long getLongLittleEndian(long address) { * }
long value = UNSAFE.getLong(address); *
* return value;
if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) { */
value = Long.reverseBytes(value);
}
return value;
} }
private static String text(Map<String, Aggregate> aggregates) { private static String text(Map<String, Aggregate> aggregates) {
@ -140,7 +134,7 @@ public class CalculateAverage_artsiomkorzun {
private static class Aggregates { private static class Aggregates {
private static final int ENTRIES = 64 * 1024; private static final int ENTRIES = 64 * 1024;
private static final int SIZE = 32 * ENTRIES; private static final int SIZE = 128 * ENTRIES;
private final long pointer; private final long pointer;
@ -150,62 +144,82 @@ public class CalculateAverage_artsiomkorzun {
UNSAFE.setMemory(pointer, SIZE, (byte) 0); UNSAFE.setMemory(pointer, SIZE, (byte) 0);
} }
// Added column: find(word, hash) inspects only the slot at pointer + offset(hash) (no probing).
// It returns that slot's address when the 8 bytes stored at slot + 24 equal `word`, otherwise 0.
// The left half of the first row is the header of the removed add(...) method.
public void add(long reference, int length, int hash, int value) { public long find(long word, int hash) {
long address = pointer + offset(hash);
long w = word(address + 24);
return (w == word) ? address : 0;
}
/**
 * Looks up a key supplied as two 8-byte words in the single slot selected by {@code hash};
 * no further slots are probed.
 *
 * @param word1 expected value of the 8 bytes stored at slot offset 24
 * @param word2 expected value of the 8 bytes stored at slot offset 32
 * @param hash  hash used to select the slot via {@code offset(hash)}
 * @return the slot address when both stored words match, otherwise 0
 */
public long find(long word1, long word2, int hash) {
    final long slot = pointer + offset(hash);
    if (word(slot + 24) != word1) {
        return 0;
    }
    if (word(slot + 32) != word2) {
        return 0;
    }
    return slot;
}
// Added column: put(reference, word, length, hash) walks slots starting at offset(hash) and
// stepping with next(offset). It returns the first slot whose stored key (at slot + 24) matches
// via equal(reference, word, ...), or, when the int at slot + 0 is 0, initialises that slot
// with alloc(...) and returns it. No value is aggregated here.
// Removed column: the body of add(...), which probed on a stored reference and updated
// sum / count / min / max in place.
public long put(long reference, long word, int length, int hash) {
for (int offset = offset(hash);; offset = next(offset)) { for (int offset = offset(hash);; offset = next(offset)) {
long address = pointer + offset; long address = pointer + offset;
long ref = UNSAFE.getLong(address); if (equal(reference, word, address + 24, length)) {
return address;
if (ref == 0) {
alloc(reference, length, hash, value, address);
break;
} }
if (equal(ref, reference, length)) { int len = UNSAFE.getInt(address);
long sum = UNSAFE.getLong(address + 16) + value; if (len == 0) {
int cnt = UNSAFE.getInt(address + 24) + 1; alloc(reference, length, hash, address);
short min = (short) Math.min(UNSAFE.getShort(address + 28), value); return address;
short max = (short) Math.max(UNSAFE.getShort(address + 30), value);
UNSAFE.putLong(address + 16, sum);
UNSAFE.putInt(address + 24, cnt);
UNSAFE.putShort(address + 28, min);
UNSAFE.putShort(address + 30, max);
break;
} }
} }
} }
/**
 * Folds one measurement into the aggregate slot at {@code address}.
 * Fields touched: long sum at +8, int count at +16, short min at +20, short max at +22.
 *
 * @param address base address of the slot
 * @param value   measurement to add; narrowed to {@code short} when stored as min or max
 */
public static void update(long address, int value) {
    // Running total and sample count.
    UNSAFE.putLong(address + 8, UNSAFE.getLong(address + 8) + value);
    UNSAFE.putInt(address + 16, UNSAFE.getInt(address + 16) + 1);

    // Extremes are rewritten only when the new value actually beats them.
    if (value < UNSAFE.getShort(address + 20)) {
        UNSAFE.putShort(address + 20, (short) value);
    }
    if (value > UNSAFE.getShort(address + 22)) {
        UNSAFE.putShort(address + 22, (short) value);
    }
}
public void merge(Aggregates rights) { public void merge(Aggregates rights) {
for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 32) { for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 128) {
long rightAddress = rights.pointer + rightOffset; long rightAddress = rights.pointer + rightOffset;
long reference = UNSAFE.getLong(rightAddress); int length = UNSAFE.getInt(rightAddress);
if (reference == 0) { if (length == 0) {
continue; continue;
} }
int hash = UNSAFE.getInt(rightAddress + 8); int hash = UNSAFE.getInt(rightAddress + 4);
int length = UNSAFE.getInt(rightAddress + 12);
for (int offset = offset(hash);; offset = next(offset)) { for (int offset = offset(hash);; offset = next(offset)) {
long address = pointer + offset; long address = pointer + offset;
long ref = UNSAFE.getLong(address); int len = UNSAFE.getInt(address);
if (ref == 0) { if (len == 0) {
UNSAFE.copyMemory(rightAddress, address, 32); UNSAFE.copyMemory(rightAddress, address, 24 + length);
break; break;
} }
if (equal(ref, reference, length)) { if (len == length && equal(address + 24, rightAddress + 24, length)) {
long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16); long sum = UNSAFE.getLong(address + 8) + UNSAFE.getLong(rightAddress + 8);
int cnt = UNSAFE.getInt(address + 24) + UNSAFE.getInt(rightAddress + 24); int cnt = UNSAFE.getInt(address + 16) + UNSAFE.getInt(rightAddress + 16);
short min = (short) Math.min(UNSAFE.getShort(address + 28), UNSAFE.getShort(rightAddress + 28)); short min = (short) Math.min(UNSAFE.getShort(address + 20), UNSAFE.getShort(rightAddress + 20));
short max = (short) Math.max(UNSAFE.getShort(address + 30), UNSAFE.getShort(rightAddress + 30)); short max = (short) Math.max(UNSAFE.getShort(address + 22), UNSAFE.getShort(rightAddress + 22));
UNSAFE.putLong(address + 16, sum); UNSAFE.putLong(address + 8, sum);
UNSAFE.putInt(address + 24, cnt); UNSAFE.putInt(address + 16, cnt);
UNSAFE.putShort(address + 28, min); UNSAFE.putShort(address + 20, min);
UNSAFE.putShort(address + 30, max); UNSAFE.putShort(address + 22, max);
break; break;
} }
} }
@ -215,20 +229,19 @@ public class CalculateAverage_artsiomkorzun {
public Map<String, Aggregate> aggregate() { public Map<String, Aggregate> aggregate() {
TreeMap<String, Aggregate> set = new TreeMap<>(); TreeMap<String, Aggregate> set = new TreeMap<>();
for (int offset = 0; offset < SIZE; offset += 32) { for (int offset = 0; offset < SIZE; offset += 128) {
long address = pointer + offset; long address = pointer + offset;
long ref = UNSAFE.getLong(address); int length = UNSAFE.getInt(address);
if (ref != 0) { if (length != 0) {
int length = UNSAFE.getInt(address + 12) - 1;
byte[] array = new byte[length]; byte[] array = new byte[length];
UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); UNSAFE.copyMemory(null, address + 24, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
String key = new String(array); String key = new String(array);
long sum = UNSAFE.getLong(address + 16); long sum = UNSAFE.getLong(address + 8);
int cnt = UNSAFE.getInt(address + 24); int cnt = UNSAFE.getInt(address + 16);
short min = UNSAFE.getShort(address + 28); short min = UNSAFE.getShort(address + 20);
short max = UNSAFE.getShort(address + 30); short max = UNSAFE.getShort(address + 22);
Aggregate aggregate = new Aggregate(min, max, sum, cnt); Aggregate aggregate = new Aggregate(min, max, sum, cnt);
set.put(key, aggregate); set.put(key, aggregate);
@ -238,26 +251,24 @@ public class CalculateAverage_artsiomkorzun {
return set; return set;
} }
// Slot initialisation, removed column vs. added column.
// Removed layout: long reference @0, int hash @8, int length @12, long sum @16, int count @24,
// short min @28, short max @30, all seeded from the first value.
// Added layout: int length @0, int hash @4, short @20 = Short.MAX_VALUE, short @22 =
// Short.MIN_VALUE, and `length` key bytes copied from `reference` to @24. The sum (@8) and
// count (@16) fields are not written by the added version.
private static void alloc(long reference, int length, int hash, int value, long address) { private static void alloc(long reference, int length, int hash, long address) {
UNSAFE.putLong(address, reference); UNSAFE.putInt(address, length);
UNSAFE.putInt(address + 8, hash); UNSAFE.putInt(address + 4, hash);
UNSAFE.putInt(address + 12, length); UNSAFE.putShort(address + 20, Short.MAX_VALUE);
UNSAFE.putLong(address + 16, value); UNSAFE.putShort(address + 22, Short.MIN_VALUE);
UNSAFE.putInt(address + 24, 1); UNSAFE.copyMemory(reference, address + 24, length);
UNSAFE.putShort(address + 28, (short) value);
UNSAFE.putShort(address + 30, (short) value);
} }
// Maps a hash to the byte offset of its slot: the hash is reduced to an entry index with
// & (ENTRIES - 1), then scaled by the slot size: 32 bytes (<< 5) in the removed column,
// 128 bytes (<< 7) in the added column.
private static int offset(int hash) { private static int offset(int hash) {
return ((hash) & (ENTRIES - 1)) << 5; return ((hash) & (ENTRIES - 1)) << 7;
} }
// Advances a byte offset by one slot and wraps it back into the table with & (SIZE - 1).
// The step is 32 bytes in the removed column and 128 bytes in the added column.
private static int next(int prev) { private static int next(int prev) {
return (prev + 32) & (SIZE - 1); return (prev + 128) & (SIZE - 1);
} }
private static boolean equal(long leftAddress, long rightAddress, int length) { private static boolean equal(long leftAddress, long leftWord, long rightAddress, int length) {
while (length > 8) { while (length >= 8) {
long left = UNSAFE.getLong(leftAddress); long left = UNSAFE.getLong(leftAddress);
long right = UNSAFE.getLong(rightAddress); long right = UNSAFE.getLong(rightAddress);
@ -270,10 +281,24 @@ public class CalculateAverage_artsiomkorzun {
length -= 8; length -= 8;
} }
int shift = (8 - length) << 3; return leftWord == word(rightAddress);
long left = getLongLittleEndian(leftAddress) << shift; }
long right = getLongLittleEndian(rightAddress) << shift;
// Added column: equal(leftAddress, rightAddress, length) compares two memory regions one
// 8-byte word at a time and returns false on the first differing word. Because it is a
// do/while, at least one word is always compared, and the final word is compared in full even
// when `length` is not a multiple of 8, so bytes past `length` take part in the result.
// The left half of the first row is the last statement of the removed tail comparison.
return (left == right); private static boolean equal(long leftAddress, long rightAddress, int length) {
do {
long left = UNSAFE.getLong(leftAddress);
long right = UNSAFE.getLong(rightAddress);
if (left != right) {
return false;
}
leftAddress += 8;
rightAddress += 8;
length -= 8;
} while (length > 0);
return true;
} }
} }
@ -320,43 +345,87 @@ public class CalculateAverage_artsiomkorzun {
// as a result a read will be split across pages, where one of them is not mapped // as a result a read will be split across pages, where one of them is not mapped
// but for some reason it works on my machine, leaving to investigate // but for some reason it works on my machine, leaving to investigate
for (long start = position, hash = 0; position <= limit;) { while (position <= limit) { // branchy version, credit: thomaswue
int length; // idea: royvanrijn, explanation: https://richardstartin.github.io/posts/finding-bytes int length;
{ int hash;
long word = getLongLittleEndian(position);
long match = word ^ COMMA_PATTERN;
long mask = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
if (mask == 0) { long ptr = 0;
hash ^= word; long word = word(position);
position += 8; long separator = separator(word);
if (separator != 0) {
length = length(separator);
word = mask(word, separator);
hash = mix(word);
ptr = aggregates.find(word, hash);
}
else {
long word0 = word;
word = word(position + 8);
separator = separator(word);
if (separator != 0) {
length = length(separator) + 8;
word = mask(word, separator);
hash = mix(word ^ word0);
ptr = aggregates.find(word0, word, hash);
}
else {
length = 16;
long h = word ^ word0;
while (true) {
word = word(position + length);
separator = separator(word);
if (separator == 0) {
length += 8;
h ^= word;
continue; continue;
} }
int bit = Long.numberOfTrailingZeros(mask); length += length(separator);
position += (bit >>> 3) + 1; // +sep word = mask(word, separator);
hash ^= (word << (69 - bit)); hash = mix(h ^ word);
length = (int) (position - start); break;
}
}
} }
int value; // idea: merykitty if (ptr == 0) {
{ ptr = aggregates.put(position, word, length, hash);
long word = getLongLittleEndian(position); }
position = update(ptr, position + length + 1);
}
}
// Added column: update(ptr, position) decodes a number from the 8 bytes at `position` without
// branching, passes it to Aggregates.update(ptr, value) and returns position + (dot >> 3) + 3.
// Removed column: the same arithmetic written inline in the old loop.
//  - dot:    bit index of the first byte among bytes 1..3 whose bit 4 is clear (DOT_BITS).
//  - signed: -1 when bit 4 of byte 0 is clear, otherwise 0; mask then blanks byte 0 in that case.
//  - digits * MAGIC_MULTIPLIER sums the selected low nibbles with weights 100, 10 and 1.
//  - (abs ^ signed) - signed negates abs when signed is -1.
// NOTE(review): presumably the input is ASCII like "-12.3" followed by a one-byte line
// terminator ('-' and '.' have bit 4 clear, digits have it set) — confirm against the caller.
private static long update(long ptr, long position) {
// idea: merykitty
long word = word(position);
long inverted = ~word; long inverted = ~word;
int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS);
long signed = (inverted << 59) >> 63; long signed = (inverted << 59) >> 63;
long mask = ~(signed & 0xFF); long mask = ~(signed & 0xFF);
long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L;
long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
value = (int) ((abs ^ signed) - signed); int value = (int) ((abs ^ signed) - signed);
position += (dot >> 3) + 3;
Aggregates.update(ptr, value);
return position + (dot >> 3) + 3;
} }
// Added column: separator(word) XORs the word with COMMA_PATTERN so that matching bytes become
// zero, then applies the (x - 0x01..01) & ~x & 0x80..80 zero-byte test. The result is 0 when no
// byte matches; otherwise its lowest set bit is bit 7 of the first matching byte (flags above
// that byte are not guaranteed to be exact).
// Removed column: the tail of the old parse loop (add call and per-record reset).
aggregates.add(start, length, mix(hash), value); private static long separator(long word) {
long match = word ^ COMMA_PATTERN;
start = position; return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
hash = 0;
}
/**
 * Keeps only the bytes of {@code word} that lie below the first flagged byte of
 * {@code separator}, and always clears the most significant byte.
 *
 * @param word      eight raw input bytes
 * @param separator flag word with bit 7 set in the separator byte(s), as produced by
 *                  {@code separator(word)}
 * @return {@code word} with the first flagged byte and everything above it zeroed
 */
private static long mask(long word, long separator) {
    // Moving each flag from bit 7 to bit 0 and subtracting one yields all ones strictly
    // below the lowest flag.
    final long belowFirstFlag = (separator >>> 7) - 1;
    // Higher flags survive the subtraction as single bits; the constant drops the top byte.
    // NOTE(review): presumably a second flag can only ever land in byte 7 — confirm.
    final long lowSevenBytes = 0x00FFFFFFFFFFFFFFL;
    return word & belowFirstFlag & lowSevenBytes;
}
// Converts a separator flag word into a byte index: the number of trailing zero bits divided
// by 8, i.e. the index of the byte holding the lowest set bit.
private static int length(long separator) {
return Long.numberOfTrailingZeros(separator) >>> 3;
} }
private static long next(long position) { private static long next(long position) {