use thomaswue trick, use parallelism, slightly faster (#560)
This commit is contained in:
parent
8bae1b8781
commit
98a8279669
@ -16,10 +16,10 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
source "$HOME/.sdkman/bin/sdkman-init.sh"
|
source "$HOME/.sdkman/bin/sdkman-init.sh"
|
||||||
sdk use java 21.0.1-graal 1>&2
|
sdk use java 21.0.2-graal 1>&2
|
||||||
|
|
||||||
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
|
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
|
||||||
if [ ! -f target/CalculateAverage_abeobk_image ]; then
|
if [ ! -f target/CalculateAverage_abeobk_image ]; then
|
||||||
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
|
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=128m -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_abeobk"
|
||||||
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk
|
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk
|
||||||
fi
|
fi
|
||||||
|
@ -24,8 +24,12 @@ import java.nio.channels.FileChannel.MapMode;
|
|||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.StandardOpenOption;
|
import java.nio.file.StandardOpenOption;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
import sun.misc.Unsafe;
|
import sun.misc.Unsafe;
|
||||||
|
|
||||||
public class CalculateAverage_abeobk {
|
public class CalculateAverage_abeobk {
|
||||||
@ -66,22 +70,23 @@ public class CalculateAverage_abeobk {
|
|||||||
long addr;
|
long addr;
|
||||||
long word0;
|
long word0;
|
||||||
long tail;
|
long tail;
|
||||||
int keylen;
|
|
||||||
int min, max;
|
|
||||||
int count;
|
|
||||||
long sum;
|
long sum;
|
||||||
|
int count;
|
||||||
|
short min, max;
|
||||||
|
int keylen;
|
||||||
|
String key;
|
||||||
|
|
||||||
String key() {
|
void calcKey() {
|
||||||
byte[] sbuf = new byte[MAX_STR_LEN];
|
byte[] sbuf = new byte[MAX_STR_LEN];
|
||||||
UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
|
UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
|
||||||
return new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
|
key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
|
||||||
}
|
}
|
||||||
|
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1);
|
return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1);
|
||||||
}
|
}
|
||||||
|
|
||||||
Node(long a, long t, int val, int kl) {
|
Node(long a, long t, short val, int kl) {
|
||||||
addr = a;
|
addr = a;
|
||||||
tail = t;
|
tail = t;
|
||||||
keylen = kl;
|
keylen = kl;
|
||||||
@ -89,12 +94,16 @@ public class CalculateAverage_abeobk {
|
|||||||
count = 1;
|
count = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Node(long a, long t, int val, int kl, long w0) {
|
Node(long a, long w0, long t, short val, int kl) {
|
||||||
this(a, t, val, kl);
|
addr = a;
|
||||||
word0 = w0;
|
word0 = w0;
|
||||||
|
tail = t;
|
||||||
|
keylen = kl;
|
||||||
|
sum = min = max = val;
|
||||||
|
count = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
void add(int val) {
|
void add(short val) {
|
||||||
sum += val;
|
sum += val;
|
||||||
count++;
|
count++;
|
||||||
if (val >= max) {
|
if (val >= max) {
|
||||||
@ -107,19 +116,23 @@ public class CalculateAverage_abeobk {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void merge(Node other) {
|
void merge(Node other) {
|
||||||
min = Math.min(min, other.min);
|
|
||||||
max = Math.max(max, other.max);
|
|
||||||
sum += other.sum;
|
sum += other.sum;
|
||||||
count += other.count;
|
count += other.count;
|
||||||
|
if (other.max > max) {
|
||||||
|
max = other.max;
|
||||||
|
}
|
||||||
|
if (other.min < min) {
|
||||||
|
min = other.min;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean contentEquals(long other_addr, long other_tail) {
|
boolean contentEquals(long other_addr, long other_word0, long other_tail) {
|
||||||
if (tail != other_tail)
|
if (tail != other_tail || word0 != other_word0)
|
||||||
return false;
|
return false;
|
||||||
// this is faster than comparision if key is short
|
// this is faster than comparision if key is short
|
||||||
long xsum = 0;
|
long xsum = 0;
|
||||||
int n = keylen & 0xF8;
|
int n = keylen & 0xF8;
|
||||||
for (int i = 0; i < n; i += 8) {
|
for (int i = 8; i < n; i += 8) {
|
||||||
xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i));
|
xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i));
|
||||||
}
|
}
|
||||||
return xsum == 0;
|
return xsum == 0;
|
||||||
@ -156,29 +169,27 @@ public class CalculateAverage_abeobk {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// great idea from merykitty (Quan Anh Mai)
|
// great idea from merykitty (Quan Anh Mai)
|
||||||
static final int parseNum(long num_word, int dot_pos) {
|
static final short parseNum(long num_word, int dot_pos) {
|
||||||
int shift = 28 - dot_pos;
|
int shift = 28 - dot_pos;
|
||||||
long signed = (~num_word << 59) >> 63;
|
long signed = (~num_word << 59) >> 63;
|
||||||
long dsmask = ~(signed & 0xFF);
|
long dsmask = ~(signed & 0xFF);
|
||||||
long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
|
long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
|
||||||
long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
|
long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
|
||||||
return (int) ((abs_val ^ signed) - signed);
|
return (short) ((abs_val ^ signed) - signed);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optimize for contest
|
// optimize for contest
|
||||||
// save as much slow memory access as possible
|
// save as much slow memory access as possible
|
||||||
// about 50% key < 8chars, 25% key bettween 8-10 chars
|
// about 50% key < 8chars, 25% key bettween 8-10 chars
|
||||||
// keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
|
// keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
|
||||||
static final Node[] parse(int thread_id, long start, long end, int[] cls) {
|
static final Node[] parse(int thread_id, long start, long end) {
|
||||||
|
int cls = 0;
|
||||||
long addr = start;
|
long addr = start;
|
||||||
var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
|
var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
|
||||||
// parse loop
|
// parse loop
|
||||||
while (addr < end) {
|
while (addr < end) {
|
||||||
long row_addr = addr;
|
long row_addr = addr;
|
||||||
long tail = 0;
|
|
||||||
long hash = 0;
|
long hash = 0;
|
||||||
int val = 0;
|
|
||||||
int bucket = 0;
|
|
||||||
|
|
||||||
long word0 = UNSAFE.getLong(addr);
|
long word0 = UNSAFE.getLong(addr);
|
||||||
long semipos_code = getSemiPosCode(word0);
|
long semipos_code = getSemiPosCode(word0);
|
||||||
@ -191,9 +202,9 @@ public class CalculateAverage_abeobk {
|
|||||||
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
||||||
addr += (dot_pos >>> 3) + 3;
|
addr += (dot_pos >>> 3) + 3;
|
||||||
|
|
||||||
tail = (word0 & HASH_MASKS[semi_pos]);
|
long tail = (word0 & HASH_MASKS[semi_pos]);
|
||||||
bucket = xxh32(tail) & BUCKET_MASK;
|
int bucket = xxh32(tail) & BUCKET_MASK;
|
||||||
val = parseNum(num_word, dot_pos);
|
short val = parseNum(num_word, dot_pos);
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
var node = map[bucket];
|
var node = map[bucket];
|
||||||
@ -207,7 +218,7 @@ public class CalculateAverage_abeobk {
|
|||||||
}
|
}
|
||||||
bucket++;
|
bucket++;
|
||||||
if (SHOW_ANALYSIS)
|
if (SHOW_ANALYSIS)
|
||||||
cls[thread_id]++;
|
cls++;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -225,15 +236,15 @@ public class CalculateAverage_abeobk {
|
|||||||
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
||||||
addr += (dot_pos >>> 3) + 4;
|
addr += (dot_pos >>> 3) + 4;
|
||||||
|
|
||||||
tail = (word & HASH_MASKS[semi_pos]);
|
long tail = (word & HASH_MASKS[semi_pos]);
|
||||||
hash ^= tail;
|
hash ^= tail;
|
||||||
bucket = xxh32(hash) & BUCKET_MASK;
|
int bucket = xxh32(hash) & BUCKET_MASK;
|
||||||
val = parseNum(num_word, dot_pos);
|
short val = parseNum(num_word, dot_pos);
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
var node = map[bucket];
|
var node = map[bucket];
|
||||||
if (node == null) {
|
if (node == null) {
|
||||||
map[bucket] = new Node(row_addr, tail, val, keylen, word0);
|
map[bucket] = new Node(row_addr, word0, tail, val, keylen);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (node.word0 == word0 && node.tail == tail) {
|
if (node.word0 == word0 && node.tail == tail) {
|
||||||
@ -242,7 +253,7 @@ public class CalculateAverage_abeobk {
|
|||||||
}
|
}
|
||||||
bucket++;
|
bucket++;
|
||||||
if (SHOW_ANALYSIS)
|
if (SHOW_ANALYSIS)
|
||||||
cls[thread_id]++;
|
cls++;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -261,30 +272,55 @@ public class CalculateAverage_abeobk {
|
|||||||
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
||||||
addr += (dot_pos >>> 3) + 4;
|
addr += (dot_pos >>> 3) + 4;
|
||||||
|
|
||||||
tail = (word & HASH_MASKS[semi_pos]);
|
long tail = (word & HASH_MASKS[semi_pos]);
|
||||||
hash ^= tail;
|
hash ^= tail;
|
||||||
bucket = xxh32(hash) & BUCKET_MASK;
|
int bucket = xxh32(hash) & BUCKET_MASK;
|
||||||
val = parseNum(num_word, dot_pos);
|
short val = parseNum(num_word, dot_pos);
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
var node = map[bucket];
|
var node = map[bucket];
|
||||||
if (node == null) {
|
if (node == null) {
|
||||||
map[bucket] = new Node(row_addr, tail, val, keylen);
|
map[bucket] = new Node(row_addr, word0, tail, val, keylen);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (node.contentEquals(row_addr, tail)) {
|
if (node.contentEquals(row_addr, word0, tail)) {
|
||||||
node.add(val);
|
node.add(val);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
bucket++;
|
bucket++;
|
||||||
if (SHOW_ANALYSIS)
|
if (SHOW_ANALYSIS)
|
||||||
cls[thread_id]++;
|
cls++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (SHOW_ANALYSIS) {
|
||||||
|
debug("Thread %d collision = %d", thread_id, cls);
|
||||||
|
}
|
||||||
return map;
|
return map;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// thomaswue trick
|
||||||
|
private static void spawnWorker() throws IOException {
|
||||||
|
ProcessHandle.Info info = ProcessHandle.current().info();
|
||||||
|
ArrayList<String> workerCommand = new ArrayList<>();
|
||||||
|
info.command().ifPresent(workerCommand::add);
|
||||||
|
info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args)));
|
||||||
|
workerCommand.add("--worker");
|
||||||
|
new ProcessBuilder()
|
||||||
|
.command(workerCommand)
|
||||||
|
.inheritIO()
|
||||||
|
.redirectOutput(ProcessBuilder.Redirect.PIPE)
|
||||||
|
.start()
|
||||||
|
.getInputStream()
|
||||||
|
.transferTo(System.out);
|
||||||
|
}
|
||||||
|
|
||||||
public static void main(String[] args) throws InterruptedException, IOException {
|
public static void main(String[] args) throws InterruptedException, IOException {
|
||||||
|
// thomaswue trick
|
||||||
|
if (args.length == 0 || !("--worker".equals(args[0]))) {
|
||||||
|
spawnWorker();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
|
try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
|
||||||
long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
|
long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
|
||||||
long file_size = file.size();
|
long file_size = file.size();
|
||||||
@ -295,51 +331,48 @@ public class CalculateAverage_abeobk {
|
|||||||
long chunk_size = Math.ceilDiv(file_size, cpu_cnt);
|
long chunk_size = Math.ceilDiv(file_size, cpu_cnt);
|
||||||
|
|
||||||
// processing
|
// processing
|
||||||
var threads = new Thread[cpu_cnt];
|
|
||||||
var maps = new Node[cpu_cnt][];
|
|
||||||
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
|
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
|
||||||
|
|
||||||
int[] cls = new int[cpu_cnt]; // collision
|
TreeMap<String, Node> ms = new TreeMap<>();
|
||||||
int[] lenhist = new int[64]; // length histogram
|
int[] lenhist = new int[64]; // length histogram
|
||||||
|
|
||||||
for (int i = 0; i < cpu_cnt; i++) {
|
List<List<Node>> maps = IntStream.range(0, cpu_cnt)
|
||||||
int thread_id = i;
|
.mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1]))
|
||||||
(threads[thread_id] = new Thread(() -> {
|
.map(map -> {
|
||||||
maps[thread_id] = parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1], cls);
|
List<Node> nodes = new ArrayList<>();
|
||||||
})).start();
|
for (var node : map) {
|
||||||
}
|
if (node == null)
|
||||||
|
continue;
|
||||||
|
node.calcKey();
|
||||||
|
nodes.add(node);
|
||||||
|
}
|
||||||
|
return nodes;
|
||||||
|
})
|
||||||
|
.parallel()
|
||||||
|
.toList();
|
||||||
|
|
||||||
// join all
|
for (var nodes : maps) {
|
||||||
for (var thread : threads)
|
for (var node : nodes) {
|
||||||
thread.join();
|
|
||||||
|
|
||||||
// collect results
|
|
||||||
TreeMap<String, Node> ms = new TreeMap<>();
|
|
||||||
for (var map : maps) {
|
|
||||||
for (var node : map) {
|
|
||||||
if (node == null)
|
|
||||||
continue;
|
|
||||||
if (SHOW_ANALYSIS) {
|
if (SHOW_ANALYSIS) {
|
||||||
int kl = node.keylen & (lenhist.length - 1);
|
int kl = node.keylen & (lenhist.length - 1);
|
||||||
lenhist[kl] += node.count;
|
lenhist[kl] += node.count;
|
||||||
}
|
}
|
||||||
var stat = ms.putIfAbsent(node.key(), node);
|
var stat = ms.putIfAbsent(node.key, node);
|
||||||
if (stat != null)
|
if (stat != null)
|
||||||
stat.merge(node);
|
stat.merge(node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (SHOW_ANALYSIS) {
|
if (SHOW_ANALYSIS) {
|
||||||
debug("Collision stat: ");
|
|
||||||
for (int i = 0; i < cpu_cnt; i++) {
|
|
||||||
debug("thread-" + i + " collision = " + cls[i]);
|
|
||||||
}
|
|
||||||
debug("Total = " + Arrays.stream(lenhist).sum());
|
debug("Total = " + Arrays.stream(lenhist).sum());
|
||||||
debug("Length_histogram = "
|
debug("Length_histogram = "
|
||||||
+ Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray()));
|
+ Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray()));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
else
|
|
||||||
System.out.println(ms);
|
// print result
|
||||||
|
System.out.println(ms);
|
||||||
|
System.out.close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user