use thomaswue trick, use parallelism, slightly faster (#560)

This commit is contained in:
Van Phu DO 2024-01-24 00:41:25 +09:00 committed by GitHub
parent 8bae1b8781
commit 98a8279669
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 96 additions and 63 deletions

View File

@ -16,10 +16,10 @@
# #
source "$HOME/.sdkman/bin/sdkman-init.sh" source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.1-graal 1>&2 sdk use java 21.0.2-graal 1>&2
# ./mvnw clean verify removes target/ and will re-trigger native image creation. # ./mvnw clean verify removes target/ and will re-trigger native image creation.
if [ ! -f target/CalculateAverage_abeobk_image ]; then if [ ! -f target/CalculateAverage_abeobk_image ]; then
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk" NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk
fi fi

View File

@ -24,8 +24,12 @@ import java.nio.channels.FileChannel.MapMode;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.StandardOpenOption; import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.stream.IntStream;
import sun.misc.Unsafe; import sun.misc.Unsafe;
public class CalculateAverage_abeobk { public class CalculateAverage_abeobk {
@ -66,22 +70,23 @@ public class CalculateAverage_abeobk {
long addr; long addr;
long word0; long word0;
long tail; long tail;
int keylen;
int min, max;
int count;
long sum; long sum;
int count;
short min, max;
int keylen;
String key;
String key() { void calcKey() {
byte[] sbuf = new byte[MAX_STR_LEN]; byte[] sbuf = new byte[MAX_STR_LEN];
UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen); UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
return new String(sbuf, 0, keylen, StandardCharsets.UTF_8); key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
} }
public String toString() { public String toString() {
return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1); return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1);
} }
Node(long a, long t, int val, int kl) { Node(long a, long t, short val, int kl) {
addr = a; addr = a;
tail = t; tail = t;
keylen = kl; keylen = kl;
@ -89,12 +94,16 @@ public class CalculateAverage_abeobk {
count = 1; count = 1;
} }
Node(long a, long t, int val, int kl, long w0) { Node(long a, long w0, long t, short val, int kl) {
this(a, t, val, kl); addr = a;
word0 = w0; word0 = w0;
tail = t;
keylen = kl;
sum = min = max = val;
count = 1;
} }
void add(int val) { void add(short val) {
sum += val; sum += val;
count++; count++;
if (val >= max) { if (val >= max) {
@ -107,19 +116,23 @@ public class CalculateAverage_abeobk {
} }
void merge(Node other) { void merge(Node other) {
min = Math.min(min, other.min);
max = Math.max(max, other.max);
sum += other.sum; sum += other.sum;
count += other.count; count += other.count;
if (other.max > max) {
max = other.max;
}
if (other.min < min) {
min = other.min;
}
} }
boolean contentEquals(long other_addr, long other_tail) { boolean contentEquals(long other_addr, long other_word0, long other_tail) {
if (tail != other_tail) if (tail != other_tail || word0 != other_word0)
return false; return false;
// this is faster than comparison if key is short // this is faster than comparison if key is short
long xsum = 0; long xsum = 0;
int n = keylen & 0xF8; int n = keylen & 0xF8;
for (int i = 0; i < n; i += 8) { for (int i = 8; i < n; i += 8) {
xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i)); xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i));
} }
return xsum == 0; return xsum == 0;
@ -156,29 +169,27 @@ public class CalculateAverage_abeobk {
} }
// great idea from merykitty (Quan Anh Mai) // great idea from merykitty (Quan Anh Mai)
static final int parseNum(long num_word, int dot_pos) { static final short parseNum(long num_word, int dot_pos) {
int shift = 28 - dot_pos; int shift = 28 - dot_pos;
long signed = (~num_word << 59) >> 63; long signed = (~num_word << 59) >> 63;
long dsmask = ~(signed & 0xFF); long dsmask = ~(signed & 0xFF);
long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
return (int) ((abs_val ^ signed) - signed); return (short) ((abs_val ^ signed) - signed);
} }
// optimize for contest // optimize for contest
// save as much slow memory access as possible // save as much slow memory access as possible
// about 50% key < 8 chars, 25% key between 8-10 chars // about 50% key < 8 chars, 25% key between 8-10 chars
// keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
static final Node[] parse(int thread_id, long start, long end, int[] cls) { static final Node[] parse(int thread_id, long start, long end) {
int cls = 0;
long addr = start; long addr = start;
var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
// parse loop // parse loop
while (addr < end) { while (addr < end) {
long row_addr = addr; long row_addr = addr;
long tail = 0;
long hash = 0; long hash = 0;
int val = 0;
int bucket = 0;
long word0 = UNSAFE.getLong(addr); long word0 = UNSAFE.getLong(addr);
long semipos_code = getSemiPosCode(word0); long semipos_code = getSemiPosCode(word0);
@ -191,9 +202,9 @@ public class CalculateAverage_abeobk {
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 3; addr += (dot_pos >>> 3) + 3;
tail = (word0 & HASH_MASKS[semi_pos]); long tail = (word0 & HASH_MASKS[semi_pos]);
bucket = xxh32(tail) & BUCKET_MASK; int bucket = xxh32(tail) & BUCKET_MASK;
val = parseNum(num_word, dot_pos); short val = parseNum(num_word, dot_pos);
while (true) { while (true) {
var node = map[bucket]; var node = map[bucket];
@ -207,7 +218,7 @@ public class CalculateAverage_abeobk {
} }
bucket++; bucket++;
if (SHOW_ANALYSIS) if (SHOW_ANALYSIS)
cls[thread_id]++; cls++;
} }
continue; continue;
} }
@ -225,15 +236,15 @@ public class CalculateAverage_abeobk {
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 4; addr += (dot_pos >>> 3) + 4;
tail = (word & HASH_MASKS[semi_pos]); long tail = (word & HASH_MASKS[semi_pos]);
hash ^= tail; hash ^= tail;
bucket = xxh32(hash) & BUCKET_MASK; int bucket = xxh32(hash) & BUCKET_MASK;
val = parseNum(num_word, dot_pos); short val = parseNum(num_word, dot_pos);
while (true) { while (true) {
var node = map[bucket]; var node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, tail, val, keylen, word0); map[bucket] = new Node(row_addr, word0, tail, val, keylen);
break; break;
} }
if (node.word0 == word0 && node.tail == tail) { if (node.word0 == word0 && node.tail == tail) {
@ -242,7 +253,7 @@ public class CalculateAverage_abeobk {
} }
bucket++; bucket++;
if (SHOW_ANALYSIS) if (SHOW_ANALYSIS)
cls[thread_id]++; cls++;
} }
continue; continue;
} }
@ -261,30 +272,55 @@ public class CalculateAverage_abeobk {
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 4; addr += (dot_pos >>> 3) + 4;
tail = (word & HASH_MASKS[semi_pos]); long tail = (word & HASH_MASKS[semi_pos]);
hash ^= tail; hash ^= tail;
bucket = xxh32(hash) & BUCKET_MASK; int bucket = xxh32(hash) & BUCKET_MASK;
val = parseNum(num_word, dot_pos); short val = parseNum(num_word, dot_pos);
while (true) { while (true) {
var node = map[bucket]; var node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, tail, val, keylen); map[bucket] = new Node(row_addr, word0, tail, val, keylen);
break; break;
} }
if (node.contentEquals(row_addr, tail)) { if (node.contentEquals(row_addr, word0, tail)) {
node.add(val); node.add(val);
break; break;
} }
bucket++; bucket++;
if (SHOW_ANALYSIS) if (SHOW_ANALYSIS)
cls[thread_id]++; cls++;
} }
} }
if (SHOW_ANALYSIS) {
debug("Thread %d collision = %d", thread_id, cls);
}
return map; return map;
} }
/**
 * thomaswue trick: re-launches the current process as a child with an extra
 * {@code --worker} argument and copies the child's standard output to this
 * process's {@code System.out}.
 *
 * <p>The child command line is rebuilt from {@link ProcessHandle.Info}: the
 * executable (if reported), then the original arguments (if reported), then
 * {@code --worker} as the last element. Standard input and standard error
 * are inherited; only standard output is piped back through the parent.
 *
 * <p>NOTE(review): {@code --worker} is appended at the end, while the caller
 * appears to test only {@code args[0]} - confirm behavior when the program is
 * started with its own arguments.
 *
 * @throws IOException if the child cannot be started or copying its output fails
 */
private static void spawnWorker() throws IOException {
    var self = ProcessHandle.current().info();

    // Rebuild "<executable> <original args...> --worker".
    var childCommand = new ArrayList<String>();
    self.command().ifPresent(childCommand::add);
    self.arguments().ifPresent(argv -> {
        for (String arg : argv) {
            childCommand.add(arg);
        }
    });
    childCommand.add("--worker");

    // inheritIO() first, then override stdout so it alone becomes a pipe.
    var builder = new ProcessBuilder();
    builder.command(childCommand);
    builder.inheritIO();
    builder.redirectOutput(ProcessBuilder.Redirect.PIPE);

    Process worker = builder.start();
    // Returns once the child closes its stdout; the child's exit is not awaited.
    worker.getInputStream().transferTo(System.out);
}
public static void main(String[] args) throws InterruptedException, IOException { public static void main(String[] args) throws InterruptedException, IOException {
// thomaswue trick
if (args.length == 0 || !("--worker".equals(args[0]))) {
spawnWorker();
return;
}
try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address(); long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
long file_size = file.size(); long file_size = file.size();
@ -295,51 +331,48 @@ public class CalculateAverage_abeobk {
long chunk_size = Math.ceilDiv(file_size, cpu_cnt); long chunk_size = Math.ceilDiv(file_size, cpu_cnt);
// processing // processing
var threads = new Thread[cpu_cnt];
var maps = new Node[cpu_cnt][];
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt); var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
int[] cls = new int[cpu_cnt]; // collision TreeMap<String, Node> ms = new TreeMap<>();
int[] lenhist = new int[64]; // length histogram int[] lenhist = new int[64]; // length histogram
for (int i = 0; i < cpu_cnt; i++) { List<List<Node>> maps = IntStream.range(0, cpu_cnt)
int thread_id = i; .mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1]))
(threads[thread_id] = new Thread(() -> { .map(map -> {
maps[thread_id] = parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1], cls); List<Node> nodes = new ArrayList<>();
})).start(); for (var node : map) {
} if (node == null)
continue;
node.calcKey();
nodes.add(node);
}
return nodes;
})
.parallel()
.toList();
// join all for (var nodes : maps) {
for (var thread : threads) for (var node : nodes) {
thread.join();
// collect results
TreeMap<String, Node> ms = new TreeMap<>();
for (var map : maps) {
for (var node : map) {
if (node == null)
continue;
if (SHOW_ANALYSIS) { if (SHOW_ANALYSIS) {
int kl = node.keylen & (lenhist.length - 1); int kl = node.keylen & (lenhist.length - 1);
lenhist[kl] += node.count; lenhist[kl] += node.count;
} }
var stat = ms.putIfAbsent(node.key(), node); var stat = ms.putIfAbsent(node.key, node);
if (stat != null) if (stat != null)
stat.merge(node); stat.merge(node);
} }
} }
if (SHOW_ANALYSIS) { if (SHOW_ANALYSIS) {
debug("Collision stat: ");
for (int i = 0; i < cpu_cnt; i++) {
debug("thread-" + i + " collision = " + cls[i]);
}
debug("Total = " + Arrays.stream(lenhist).sum()); debug("Total = " + Arrays.stream(lenhist).sum());
debug("Length_histogram = " debug("Length_histogram = "
+ Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray())); + Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray()));
return;
} }
else
System.out.println(ms); // print result
System.out.println(ms);
System.out.close();
} }
} }
} }