improved artsiomkorzun solution (#321)

This commit is contained in:
Artsiom Korzun 2024-01-11 21:08:15 +01:00 committed by GitHub
parent 1a82c77026
commit 8ef8cd2b17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 139 additions and 176 deletions

View File

@ -15,5 +15,5 @@
# limitations under the License. # limitations under the License.
# #
JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC" JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_artsiomkorzun java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_artsiomkorzun

View File

@ -35,7 +35,7 @@ public class CalculateAverage_artsiomkorzun {
private static final MemorySegment MAPPED_FILE = map(FILE); private static final MemorySegment MAPPED_FILE = map(FILE);
private static final int PARALLELISM = Runtime.getRuntime().availableProcessors(); private static final int PARALLELISM = Runtime.getRuntime().availableProcessors();
private static final int SEGMENT_SIZE = 16 * 1024 * 1024; private static final int SEGMENT_SIZE = 32 * 1024 * 1024;
private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE); private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE);
private static final int SEGMENT_OVERLAP = 1024; private static final int SEGMENT_OVERLAP = 1024;
private static final long COMMA_PATTERN = pattern(';'); private static final long COMMA_PATTERN = pattern(';');
@ -100,16 +100,6 @@ public class CalculateAverage_artsiomkorzun {
return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56); return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56);
} }
private static long getLongBigEndian(long address) {
long value = UNSAFE.getLong(address);
if (BYTE_ORDER == ByteOrder.LITTLE_ENDIAN) {
value = Long.reverseBytes(value);
}
return value;
}
private static long getLongLittleEndian(long address) { private static long getLongLittleEndian(long address) {
long value = UNSAFE.getLong(address); long value = UNSAFE.getLong(address);
@ -144,98 +134,80 @@ public class CalculateAverage_artsiomkorzun {
return Math.round(v) / 10.0; return Math.round(v) / 10.0;
} }
private static class Row {
long address;
int length;
int hash;
int value;
}
private record Aggregate(int min, int max, long sum, int cnt) { private record Aggregate(int min, int max, long sum, int cnt) {
} }
private static class Aggregates { private static class Aggregates {
private static final int SIZE = 16 * 1024; private static final int ENTRIES = 64 * 1024;
private static final int SIZE = 32 * ENTRIES;
private final long pointer; private final long pointer;
public Aggregates() { public Aggregates() {
int size = 32 * SIZE; long address = UNSAFE.allocateMemory(SIZE + 8096);
long address = UNSAFE.allocateMemory(size + 8096);
pointer = (address + 4095) & (~4095); pointer = (address + 4095) & (~4095);
UNSAFE.setMemory(pointer, size, (byte) 0); UNSAFE.setMemory(pointer, SIZE, (byte) 0);
long word = pack(Short.MAX_VALUE, Short.MIN_VALUE, 0);
for (int i = 0; i < SIZE; i++) {
long entry = pointer + 32 * i;
UNSAFE.putLong(entry + 24, word);
}
} }
public void add(Row row) { public void add(long reference, int length, int hash, int value) {
long index = index(row.hash); for (int offset = offset(hash);; offset = next(offset)) {
long header = ((long) row.hash << 32) | (row.length); long address = pointer + offset;
long ref = UNSAFE.getLong(address);
while (true) { if (ref == 0) {
long address = pointer + (index << 5); alloc(reference, length, hash, value, address);
long head = UNSAFE.getLong(address);
long ref = UNSAFE.getLong(address + 8);
boolean isHit = (head == 0) || (head == header && equal(ref, row.address, row.length));
if (isHit) {
long sum = UNSAFE.getLong(address + 16) + row.value;
long word = UNSAFE.getLong(address + 24);
int min = Math.min(min(word), row.value);
int max = Math.max(max(word), row.value);
int cnt = cnt(word) + 1;
UNSAFE.putLong(address, header);
UNSAFE.putLong(address + 8, row.address);
UNSAFE.putLong(address + 16, sum);
UNSAFE.putLong(address + 24, pack(min, max, cnt));
break; break;
} }
index = (index + 1) & (SIZE - 1); if (equal(ref, reference, length)) {
long sum = UNSAFE.getLong(address + 16) + value;
int cnt = UNSAFE.getInt(address + 24) + 1;
short min = (short) Math.min(UNSAFE.getShort(address + 28), value);
short max = (short) Math.max(UNSAFE.getShort(address + 30), value);
UNSAFE.putLong(address + 16, sum);
UNSAFE.putInt(address + 24, cnt);
UNSAFE.putShort(address + 28, min);
UNSAFE.putShort(address + 30, max);
break;
}
} }
} }
public void merge(Aggregates rights) { public void merge(Aggregates rights) {
for (int rightIndex = 0; rightIndex < SIZE; rightIndex++) { for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 32) {
long rightAddress = rights.pointer + (rightIndex << 5); long rightAddress = rights.pointer + rightOffset;
long header = UNSAFE.getLong(rightAddress); long reference = UNSAFE.getLong(rightAddress);
long reference = UNSAFE.getLong(rightAddress + 8);
if (header == 0) { if (reference == 0) {
continue; continue;
} }
int hash = (int) (header >>> 32); int hash = UNSAFE.getInt(rightAddress + 8);
int length = (int) (header); int length = UNSAFE.getInt(rightAddress + 12);
long index = index(hash);
while (true) { for (int offset = offset(hash);; offset = next(offset)) {
long address = pointer + (index << 5); long address = pointer + offset;
long head = UNSAFE.getLong(address); long ref = UNSAFE.getLong(address);
long ref = UNSAFE.getLong(address + 8);
boolean isHit = (head == 0) || (head == header && equal(ref, reference, length));
if (isHit) { if (ref == 0) {
long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16); UNSAFE.copyMemory(rightAddress, address, 32);
long left = UNSAFE.getLong(address + 24);
long right = UNSAFE.getLong(rightAddress + 24);
int min = Math.min(min(left), min(right));
int max = Math.max(max(left), max(right));
int cnt = cnt(left) + cnt(right);
UNSAFE.putLong(address, header);
UNSAFE.putLong(address + 8, reference);
UNSAFE.putLong(address + 16, sum);
UNSAFE.putLong(address + 24, pack(min, max, cnt));
break; break;
} }
index = (index + 1) & (SIZE - 1); if (equal(ref, reference, length)) {
long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16);
int cnt = UNSAFE.getInt(address + 24) + UNSAFE.getInt(rightAddress + 24);
short min = (short) Math.min(UNSAFE.getShort(address + 28), UNSAFE.getShort(rightAddress + 28));
short max = (short) Math.max(UNSAFE.getShort(address + 30), UNSAFE.getShort(rightAddress + 30));
UNSAFE.putLong(address + 16, sum);
UNSAFE.putInt(address + 24, cnt);
UNSAFE.putShort(address + 28, min);
UNSAFE.putShort(address + 30, max);
break;
}
} }
} }
} }
@ -243,68 +215,64 @@ public class CalculateAverage_artsiomkorzun {
public Map<String, Aggregate> aggregate() { public Map<String, Aggregate> aggregate() {
TreeMap<String, Aggregate> set = new TreeMap<>(); TreeMap<String, Aggregate> set = new TreeMap<>();
for (int index = 0; index < SIZE; index++) { for (int offset = 0; offset < SIZE; offset += 32) {
long address = pointer + (index << 5); long address = pointer + offset;
long head = UNSAFE.getLong(address); long ref = UNSAFE.getLong(address);
long ref = UNSAFE.getLong(address + 8);
if (head == 0) { if (ref != 0) {
continue; int length = UNSAFE.getInt(address + 12) - 1;
byte[] array = new byte[length];
UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
String key = new String(array);
long sum = UNSAFE.getLong(address + 16);
int cnt = UNSAFE.getInt(address + 24);
short min = UNSAFE.getShort(address + 28);
short max = UNSAFE.getShort(address + 30);
Aggregate aggregate = new Aggregate(min, max, sum, cnt);
set.put(key, aggregate);
} }
int length = (int) (head);
byte[] array = new byte[length];
UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
String key = new String(array);
long sum = UNSAFE.getLong(address + 16);
long word = UNSAFE.getLong(address + 24);
Aggregate aggregate = new Aggregate(min(word), max(word), sum, cnt(word));
set.put(key, aggregate);
} }
return set; return set;
} }
private static long pack(int min, int max, int cnt) { private static void alloc(long reference, int length, int hash, int value, long address) {
return ((long) min << 48) | (((long) max & 0xFFFF) << 32) | cnt; UNSAFE.putLong(address, reference);
UNSAFE.putInt(address + 8, hash);
UNSAFE.putInt(address + 12, length);
UNSAFE.putLong(address + 16, value);
UNSAFE.putInt(address + 24, 1);
UNSAFE.putShort(address + 28, (short) value);
UNSAFE.putShort(address + 30, (short) value);
} }
private static int cnt(long word) { private static int offset(int hash) {
return (int) word; return ((hash) & (ENTRIES - 1)) << 5;
} }
private static int max(long word) { private static int next(int prev) {
return (short) (word >>> 32); return (prev + 32) & (SIZE - 1);
}
private static int min(long word) {
return (short) (word >>> 48);
}
private static long index(int hash) {
return (hash ^ (hash >> 16)) & (SIZE - 1);
} }
private static boolean equal(long leftAddress, long rightAddress, int length) { private static boolean equal(long leftAddress, long rightAddress, int length) {
int index = 0;
while (length > 8) { while (length > 8) {
long left = UNSAFE.getLong(leftAddress + index); long left = UNSAFE.getLong(leftAddress);
long right = UNSAFE.getLong(rightAddress + index); long right = UNSAFE.getLong(rightAddress);
if (left != right) { if (left != right) {
return false; return false;
} }
leftAddress += 8;
rightAddress += 8;
length -= 8; length -= 8;
index += 8;
} }
int shift = 64 - (length << 3); int shift = (8 - length) << 3;
long left = getLongBigEndian(leftAddress + index) >>> shift; long left = getLongLittleEndian(leftAddress) << shift;
long right = getLongBigEndian(rightAddress + index) >>> shift; long right = getLongLittleEndian(rightAddress) << shift;
return (left == right); return (left == right);
} }
} }
@ -323,10 +291,18 @@ public class CalculateAverage_artsiomkorzun {
@Override @Override
public void run() { public void run() {
Aggregates aggregates = new Aggregates(); Aggregates aggregates = new Aggregates();
Row row = new Row();
for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) { for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) {
aggregate(aggregates, row, segment); long position = (long) SEGMENT_SIZE * segment;
int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, MAPPED_FILE.byteSize() - position);
long address = MAPPED_FILE.address() + position;
long limit = address + Math.min(SEGMENT_SIZE, size - 1);
if (segment > 0) {
address = next(address);
}
aggregate(aggregates, address, limit);
} }
while (!result.compareAndSet(null, aggregates)) { while (!result.compareAndSet(null, aggregates)) {
@ -338,75 +314,62 @@ public class CalculateAverage_artsiomkorzun {
} }
} }
private static void aggregate(Aggregates aggregates, Row row, int segment) { private static void aggregate(Aggregates aggregates, long position, long limit) {
long position = (long) SEGMENT_SIZE * segment; // this parsing can produce seg fault at page boundaries
int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, MAPPED_FILE.byteSize() - position); // e.g. file size is 4096 and the last entry is X=0.0, which is less than 8 bytes
long address = MAPPED_FILE.address() + position; // as a result a read will be split across pages, where one of them is not mapped
long limit = address + Math.min(SEGMENT_SIZE, size - 1); // but for some reason it works on my machine, leaving to investigate
if (segment > 0) { for (long start = position, hash = 0; position <= limit;) {
address = next(address); int length; // idea: royvanrijn, explanation: https://richardstartin.github.io/posts/finding-bytes
} {
long word = getLongLittleEndian(position);
long match = word ^ COMMA_PATTERN;
long mask = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
while (address <= limit) { if (mask == 0) {
// this parsing can produce seg fault at page boundaries hash ^= word;
// e.g. file size is 4096 and the last entry is X=0.0, which is less than 8 bytes position += 8;
// as a result a read will be split across pages, where one of them is not mapped continue;
// but for some reason it works on my machine, leaving to investigate }
address = parseKey(address, row);
address = parseValue(address, row);
aggregates.add(row);
}
}
private static long next(long address) { int bit = Long.numberOfTrailingZeros(mask);
while (UNSAFE.getByte(address++) != '\n') { position += (bit >>> 3) + 1; // +sep
// continue hash ^= (word << (69 - bit));
} length = (int) (position - start);
return address;
}
// idea: royvanrijn
// explanation: https://richardstartin.github.io/posts/finding-bytes
private static long parseKey(long address, Row row) {
int length = 0;
long hash = 0;
long word;
while (true) {
word = getLongLittleEndian(address + length);
long match = word ^ COMMA_PATTERN;
long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L;
if (mask == 0) {
hash = 71 * hash + word;
length += 8;
continue;
} }
int bit = Long.numberOfTrailingZeros(mask); int value; // idea: merykitty
length += (bit >>> 3); {
hash = 71 * hash + (word & (0x00FFFFFFFFFFFFFFL >>> (63 - bit))); long word = getLongLittleEndian(position);
long inverted = ~word;
int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS);
long signed = (inverted << 59) >> 63;
long mask = ~(signed & 0xFF);
long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L;
long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
value = (int) ((abs ^ signed) - signed);
position += (dot >> 3) + 3;
}
row.address = address; aggregates.add(start, length, mix(hash), value);
row.length = length;
row.hash = Long.hashCode(hash);
return address + length + 1; start = position;
hash = 0;
} }
} }
// idea: merykitty private static long next(long position) {
private static long parseValue(long address, Row row) { while (UNSAFE.getByte(position++) != '\n') {
long word = getLongLittleEndian(address); // continue
long inverted = ~word; }
int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); return position;
long signed = (inverted << 59) >> 63; }
long mask = ~(signed & 0xFF);
long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; private static int mix(long x) {
long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; long h = x * -7046029254386353131L;
row.value = (int) ((abs ^ signed) - signed); h ^= h >>> 32;
return address + (dot >> 3) + 3; return (int) (h ^ h >>> 16);
} }
} }
} }