loop similar to thomas (#634)

This commit is contained in:
Artsiom Korzun 2024-01-29 20:36:25 +01:00 committed by GitHub
parent 1eaf8791c1
commit 5ba094c8fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -33,8 +33,9 @@ import java.util.concurrent.atomic.AtomicReference;
public class CalculateAverage_artsiomkorzun { public class CalculateAverage_artsiomkorzun {
private static final Path FILE = Path.of("./measurements.txt"); private static final Path FILE = Path.of("./measurements.txt");
private static final long SEGMENT_SIZE = 4 * 1024 * 1024; private static final long SEGMENT_SIZE = 2 * 1024 * 1024;
private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL; private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL;
private static final long LINE_PATTERN = 0x0A0A0A0A0A0A0A0AL;
private static final long DOT_BITS = 0x10101000; private static final long DOT_BITS = 0x10101000;
private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
@ -162,14 +163,14 @@ public class CalculateAverage_artsiomkorzun {
return Math.round(v) / 10.0; return Math.round(v) / 10.0;
} }
private record Aggregate(int min, int max, long sum, int cnt) { private record Aggregate(long min, long max, long sum, long cnt) {
} }
private static class Aggregates { private static class Aggregates {
private static final int ENTRIES = 64 * 1024; private static final long ENTRIES = 64 * 1024;
private static final int SIZE = 128 * ENTRIES; private static final long SIZE = 256 * ENTRIES;
private static final int MASK = (ENTRIES - 1) << 7; private static final long MASK = (ENTRIES - 1) << 8;
private final long pointer; private final long pointer;
@ -179,27 +180,27 @@ public class CalculateAverage_artsiomkorzun {
UNSAFE.setMemory(pointer, SIZE, (byte) 0); UNSAFE.setMemory(pointer, SIZE, (byte) 0);
} }
public long find(long word, int hash) { public long find(long word, long hash) {
long address = pointer + offset(hash); long address = pointer + offset(hash);
long w = word(address + 24); long w = word(address + 48);
return (w == word) ? address : 0; return (w == word) ? address : 0;
} }
public long find(long word1, long word2, int hash) { public long find(long word1, long word2, long hash) {
long address = pointer + offset(hash); long address = pointer + offset(hash);
long w1 = word(address + 24); long w1 = word(address + 48);
long w2 = word(address + 32); long w2 = word(address + 56);
return (word1 == w1) && (word2 == w2) ? address : 0; return (word1 == w1) && (word2 == w2) ? address : 0;
} }
public long put(long reference, long word, int length, int hash) { public long put(long reference, long word, long length, long hash) {
for (int offset = offset(hash);; offset = next(offset)) { for (long offset = offset(hash);; offset = next(offset)) {
long address = pointer + offset; long address = pointer + offset;
if (equal(reference, word, address + 24, length)) { if (equal(reference, word, address + 48, length)) {
return address; return address;
} }
int len = UNSAFE.getInt(address); long len = UNSAFE.getLong(address);
if (len == 0) { if (len == 0) {
alloc(reference, length, hash, address); alloc(reference, length, hash, address);
return address; return address;
@ -207,55 +208,55 @@ public class CalculateAverage_artsiomkorzun {
} }
} }
public static void update(long address, int value) { public static void update(long address, long value) {
long sum = UNSAFE.getLong(address + 8) + value; long sum = UNSAFE.getLong(address + 16) + value;
int cnt = UNSAFE.getInt(address + 16) + 1; long cnt = UNSAFE.getLong(address + 24) + 1;
short min = UNSAFE.getShort(address + 20); long min = UNSAFE.getLong(address + 32);
short max = UNSAFE.getShort(address + 22); long max = UNSAFE.getLong(address + 40);
UNSAFE.putLong(address + 8, sum); UNSAFE.putLong(address + 16, sum);
UNSAFE.putInt(address + 16, cnt); UNSAFE.putLong(address + 24, cnt);
if (value < min) { if (value < min) {
UNSAFE.putShort(address + 20, (short) value); UNSAFE.putLong(address + 32, value);
} }
if (value > max) { if (value > max) {
UNSAFE.putShort(address + 22, (short) value); UNSAFE.putLong(address + 40, value);
} }
} }
public void merge(Aggregates rights) { public void merge(Aggregates rights) {
for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 128) { for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 256) {
long rightAddress = rights.pointer + rightOffset; long rightAddress = rights.pointer + rightOffset;
int length = UNSAFE.getInt(rightAddress); long length = UNSAFE.getLong(rightAddress);
if (length == 0) { if (length == 0) {
continue; continue;
} }
int hash = UNSAFE.getInt(rightAddress + 4); long hash = UNSAFE.getLong(rightAddress + 8);
for (int offset = offset(hash);; offset = next(offset)) { for (long offset = offset(hash);; offset = next(offset)) {
long address = pointer + offset; long address = pointer + offset;
if (equal(address + 24, rightAddress + 24, length)) { if (equal(address + 48, rightAddress + 48, length)) {
long sum = UNSAFE.getLong(address + 8) + UNSAFE.getLong(rightAddress + 8); long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16);
int cnt = UNSAFE.getInt(address + 16) + UNSAFE.getInt(rightAddress + 16); long cnt = UNSAFE.getLong(address + 24) + UNSAFE.getLong(rightAddress + 24);
short min = (short) Math.min(UNSAFE.getShort(address + 20), UNSAFE.getShort(rightAddress + 20)); long min = Math.min(UNSAFE.getLong(address + 32), UNSAFE.getLong(rightAddress + 32));
short max = (short) Math.max(UNSAFE.getShort(address + 22), UNSAFE.getShort(rightAddress + 22)); long max = Math.max(UNSAFE.getLong(address + 40), UNSAFE.getLong(rightAddress + 40));
UNSAFE.putLong(address + 8, sum); UNSAFE.putLong(address + 16, sum);
UNSAFE.putInt(address + 16, cnt); UNSAFE.putLong(address + 24, cnt);
UNSAFE.putShort(address + 20, min); UNSAFE.putLong(address + 32, min);
UNSAFE.putShort(address + 22, max); UNSAFE.putLong(address + 40, max);
break; break;
} }
int len = UNSAFE.getInt(address); long len = UNSAFE.getLong(address);
if (len == 0) { if (len == 0) {
UNSAFE.copyMemory(rightAddress, address, length + 24); UNSAFE.copyMemory(rightAddress, address, length + 48);
break; break;
} }
} }
@ -265,19 +266,19 @@ public class CalculateAverage_artsiomkorzun {
public Map<String, Aggregate> aggregate() { public Map<String, Aggregate> aggregate() {
TreeMap<String, Aggregate> set = new TreeMap<>(); TreeMap<String, Aggregate> set = new TreeMap<>();
for (int offset = 0; offset < SIZE; offset += 128) { for (long offset = 0; offset < SIZE; offset += 256) {
long address = pointer + offset; long address = pointer + offset;
int length = UNSAFE.getInt(address); long length = UNSAFE.getLong(address);
if (length != 0) { if (length != 0) {
byte[] array = new byte[length - 1]; byte[] array = new byte[(int) length - 1];
UNSAFE.copyMemory(null, address + 24, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, array.length); UNSAFE.copyMemory(null, address + 48, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, array.length);
String key = new String(array); String key = new String(array);
long sum = UNSAFE.getLong(address + 8); long sum = UNSAFE.getLong(address + 16);
int cnt = UNSAFE.getInt(address + 16); long cnt = UNSAFE.getLong(address + 24);
short min = UNSAFE.getShort(address + 20); long min = UNSAFE.getLong(address + 32);
short max = UNSAFE.getShort(address + 22); long max = UNSAFE.getLong(address + 40);
Aggregate aggregate = new Aggregate(min, max, sum, cnt); Aggregate aggregate = new Aggregate(min, max, sum, cnt);
set.put(key, aggregate); set.put(key, aggregate);
@ -287,23 +288,23 @@ public class CalculateAverage_artsiomkorzun {
return set; return set;
} }
private static void alloc(long reference, int length, int hash, long address) { private static void alloc(long reference, long length, long hash, long address) {
UNSAFE.putInt(address, length); UNSAFE.putLong(address, length);
UNSAFE.putInt(address + 4, hash); UNSAFE.putLong(address + 8, hash);
UNSAFE.putShort(address + 20, Short.MAX_VALUE); UNSAFE.putLong(address + 32, Long.MAX_VALUE);
UNSAFE.putShort(address + 22, Short.MIN_VALUE); UNSAFE.putLong(address + 40, Long.MIN_VALUE);
UNSAFE.copyMemory(reference, address + 24, length); UNSAFE.copyMemory(reference, address + 48, length);
} }
private static int offset(int hash) { private static long offset(long hash) {
return hash & MASK; return hash & MASK;
} }
private static int next(int prev) { private static long next(long prev) {
return (prev + 128) & (SIZE - 1); return (prev + 256) & (SIZE - 1);
} }
private static boolean equal(long leftAddress, long leftWord, long rightAddress, int length) { private static boolean equal(long leftAddress, long leftWord, long rightAddress, long length) {
while (length > 8) { while (length > 8) {
long left = UNSAFE.getLong(leftAddress); long left = UNSAFE.getLong(leftAddress);
long right = UNSAFE.getLong(rightAddress); long right = UNSAFE.getLong(rightAddress);
@ -320,7 +321,7 @@ public class CalculateAverage_artsiomkorzun {
return leftWord == word(rightAddress); return leftWord == word(rightAddress);
} }
private static boolean equal(long leftAddress, long rightAddress, int length) { private static boolean equal(long leftAddress, long rightAddress, long length) {
do { do {
long left = UNSAFE.getLong(leftAddress); long left = UNSAFE.getLong(leftAddress);
long right = UNSAFE.getLong(rightAddress); long right = UNSAFE.getLong(rightAddress);
@ -362,7 +363,7 @@ public class CalculateAverage_artsiomkorzun {
for (int segment; (segment = counter.getAndIncrement()) < segmentCount;) { for (int segment; (segment = counter.getAndIncrement()) < segmentCount;) {
long position = SEGMENT_SIZE * segment; long position = SEGMENT_SIZE * segment;
long size = Math.min(SEGMENT_SIZE, fileSize - position - 1); long size = Math.min(SEGMENT_SIZE + 1, fileSize - position);
long start = fileAddress + position; long start = fileAddress + position;
long end = start + size; long end = start + size;
@ -374,7 +375,55 @@ public class CalculateAverage_artsiomkorzun {
long left = next(start + chunk); long left = next(start + chunk);
long right = next(start + chunk + chunk); long right = next(start + chunk + chunk);
aggregate(aggregates, start, left - 1, left, right - 1, right, end); Chunk chunk1 = new Chunk(start, left);
Chunk chunk2 = new Chunk(left, right);
Chunk chunk3 = new Chunk(right, end);
while (chunk1.has() && chunk2.has() && chunk3.has()) {
long word1 = word(chunk1.position);
long word2 = word(chunk2.position);
long word3 = word(chunk3.position);
long separator1 = separator(word1);
long separator2 = separator(word2);
long separator3 = separator(word3);
long pointer1 = find(aggregates, chunk1, word1, separator1);
long pointer2 = find(aggregates, chunk2, word2, separator2);
long pointer3 = find(aggregates, chunk3, word3, separator3);
long value1 = value(chunk1);
long value2 = value(chunk2);
long value3 = value(chunk3);
Aggregates.update(pointer1, value1);
Aggregates.update(pointer2, value2);
Aggregates.update(pointer3, value3);
}
while (chunk1.has()) {
long word1 = word(chunk1.position);
long separator1 = separator(word1);
long pointer1 = find(aggregates, chunk1, word1, separator1);
long value1 = value(chunk1);
Aggregates.update(pointer1, value1);
}
while (chunk2.has()) {
long word2 = word(chunk2.position);
long separator2 = separator(word2);
long pointer2 = find(aggregates, chunk2, word2, separator2);
long value2 = value(chunk2);
Aggregates.update(pointer2, value2);
}
while (chunk3.has()) {
long word3 = word(chunk3.position);
long separator3 = separator(word3);
long pointer3 = find(aggregates, chunk3, word3, separator3);
long value3 = value(chunk3);
Aggregates.update(pointer3, value3);
}
} }
while (!result.compareAndSet(null, aggregates)) { while (!result.compareAndSet(null, aggregates)) {
@ -387,123 +436,82 @@ public class CalculateAverage_artsiomkorzun {
} }
private static long next(long position) { private static long next(long position) {
while (UNSAFE.getByte(position++) != '\n') { while (true) {
// continue long word = word(position);
} long match = word ^ LINE_PATTERN;
return position; long line = (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
}
private static void aggregate(Aggregates aggregates, long position1, long limit1, long position2, long limit2, long position3, long limit3) { if (line == 0) {
while (position1 <= limit1 && position2 <= limit2 && position3 <= limit3) { position += 8;
long word1 = word(position1); continue;
long word2 = word(position2); }
long word3 = word(position3);
long separator1 = separator(word1); return position + (Long.numberOfTrailingZeros(line) >>> 3) + 1;
long separator2 = separator(word2);
long separator3 = separator(word3);
position1 = process(aggregates, position1, word1, separator1);
position2 = process(aggregates, position2, word2, separator2);
position3 = process(aggregates, position3, word3, separator3);
}
while (position1 <= limit1) {
long word1 = word(position1);
long separator1 = separator(word1);
position1 = process(aggregates, position1, word1, separator1);
}
while (position2 <= limit2) {
long word2 = word(position2);
long separator2 = separator(word2);
position2 = process(aggregates, position2, word2, separator2);
}
while (position3 <= limit3) {
long word3 = word(position3);
long separator3 = separator(word3);
position3 = process(aggregates, position3, word3, separator3);
} }
} }
private static long process(Aggregates aggregates, long position, long word, long separator) { private static long find(Aggregates aggregates, Chunk chunk, long word, long separator) {
long end = position; long start = chunk.position;
long hash;
int length;
int hash;
int value;
if (separator != 0) { if (separator != 0) {
length = length(separator);
word = mask(word, separator); word = mask(word, separator);
hash = mix(word); hash = mix(word);
end += length;
long num = word(end); chunk.position += length(separator);
int dot = dot(num);
value = value(num, dot);
end += (dot >> 3) + 3;
long pointer = aggregates.find(word, hash); long pointer = aggregates.find(word, hash);
if (pointer != 0) { if (pointer != 0) {
Aggregates.update(pointer, value); return pointer;
return end;
} }
} }
else { else {
long word0 = word; long word0 = word;
word = word(end + 8); word = word(start + 8);
separator = separator(word); separator = separator(word);
if (separator != 0) { if (separator != 0) {
length = length(separator) + 8;
word = mask(word, separator); word = mask(word, separator);
hash = mix(word ^ word0); hash = mix(word ^ word0);
end += length;
long num = word(end); chunk.position += length(separator) + 8;
int dot = dot(num);
value = value(num, dot);
end += (dot >> 3) + 3;
long pointer = aggregates.find(word0, word, hash); long pointer = aggregates.find(word0, word, hash);
if (pointer != 0) { if (pointer != 0) {
Aggregates.update(pointer, value); return pointer;
return end;
} }
} }
else { else {
length = 16; chunk.position += 16;
long h = word ^ word0; hash = word ^ word0;
while (true) { while (true) {
word = word(end + length); word = word(chunk.position);
separator = separator(word); separator = separator(word);
if (separator == 0) { if (separator == 0) {
length += 8; chunk.position += 8;
h ^= word; hash ^= word;
continue; continue;
} }
length += length(separator);
word = mask(word, separator); word = mask(word, separator);
hash = mix(h ^ word); hash = mix(hash ^ word);
end += length; chunk.position += length(separator);
long num = word(end);
int dot = dot(num);
value = value(num, dot);
end += (dot >> 3) + 3;
break; break;
} }
} }
} }
long pointer = aggregates.put(position, word, length, hash); long length = chunk.position - start;
Aggregates.update(pointer, value); return aggregates.put(start, word, length, hash);
return end; }
private static long value(Chunk chunk) {
long num = word(chunk.position);
long dot = dot(num);
chunk.position += (dot >> 3) + 3;
return value(num, dot);
} }
private static long separator(long word) { private static long separator(long word) {
@ -516,28 +524,42 @@ public class CalculateAverage_artsiomkorzun {
return word & mask; return word & mask;
} }
private static int length(long separator) { private static long length(long separator) {
return (Long.numberOfTrailingZeros(separator) >>> 3) + 1; return (Long.numberOfTrailingZeros(separator) >>> 3) + 1;
} }
private static int mix(long x) { private static long mix(long x) {
long h = x * -7046029254386353131L; long h = x * -7046029254386353131L;
h ^= h >>> 35; h ^= h >>> 35;
return (int) h; return h;
// h ^= h >>> 32; // h ^= h >>> 32;
// return (int) (h ^ h >>> 16); // return (int) (h ^ h >>> 16);
} }
private static int dot(long num) { private static long dot(long num) {
return Long.numberOfTrailingZeros(~num & DOT_BITS); return Long.numberOfTrailingZeros(~num & DOT_BITS);
} }
private static int value(long w, int dot) { private static long value(long w, long dot) {
long signed = (~w << 59) >> 63; long signed = (~w << 59) >> 63;
long mask = ~(signed & 0xFF); long mask = ~(signed & 0xFF);
long digits = ((w & mask) << (28 - dot)) & 0x0F000F0F00L; long digits = ((w & mask) << (28 - dot)) & 0x0F000F0F00L;
long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
return (int) ((abs ^ signed) - signed); return (abs ^ signed) - signed;
} }
} }
}
/**
 * Mutable cursor over the half-open range {@code [position, limit)}.
 *
 * <p>The class itself never moves the cursor; callers advance {@link #position}
 * directly as they consume input and poll {@link #has()} to decide when to stop.
 */
private static class Chunk {
    /** Current cursor value; mutated directly by the code that consumes the range. */
    long position;

    /** Exclusive upper bound of the range; fixed at construction. */
    final long limit;

    /**
     * Creates a cursor starting at {@code position} and ending before {@code limit}.
     *
     * @param position initial cursor value
     * @param limit exclusive upper bound
     */
    public Chunk(long position, long limit) {
        this.limit = limit;
        this.position = position;
    }

    /**
     * Reports whether the cursor is still inside the range.
     *
     * @return {@code true} while {@code position} is strictly below {@code limit}
     */
    boolean has() {
        return limit > position;
    }
}
}