Native build, less memory access, improved hash mixing (#449)
This commit is contained in:
parent
576291611d
commit
1804fc5b5f
@ -15,5 +15,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
JAVA_OPTS="--enable-preview"
|
if [ -f target/CalculateAverage_abeobk_image ]; then
|
||||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_abeobk
|
echo "Picking up existing native image 'target/CalculateAverage_abeobk_image', delete the file to select JVM mode." 1>&2
|
||||||
|
target/CalculateAverage_abeobk_image
|
||||||
|
else
|
||||||
|
JAVA_OPTS="--enable-preview"
|
||||||
|
echo "Chosing to run the app in JVM mode as no native image was found, use prepare_abeobk.sh to generate." 1>&2
|
||||||
|
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_abeobk
|
||||||
|
fi
|
||||||
|
|
||||||
|
25
prepare_abeobk.sh
Executable file
25
prepare_abeobk.sh
Executable file
@ -0,0 +1,25 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Copyright 2023 The original authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
source "$HOME/.sdkman/bin/sdkman-init.sh"
|
||||||
|
sdk use java 21.0.1-graal 1>&2
|
||||||
|
|
||||||
|
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
|
||||||
|
if [ ! -f target/CalculateAverage_abeobk_image ]; then
|
||||||
|
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview"
|
||||||
|
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_abeobk_image dev.morling.onebrc.CalculateAverage_abeobk
|
||||||
|
fi
|
@ -24,11 +24,12 @@ import java.nio.channels.FileChannel.MapMode;
|
|||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.StandardOpenOption;
|
import java.nio.file.StandardOpenOption;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
import sun.misc.Unsafe;
|
import sun.misc.Unsafe;
|
||||||
|
|
||||||
public class CalculateAverage_abeobk {
|
public class CalculateAverage_abeobk {
|
||||||
private static final boolean SHOW_COLLISIONS = false;
|
private static final boolean SHOW_ANALYSIS = false;
|
||||||
|
|
||||||
private static final String FILE = "./measurements.txt";
|
private static final String FILE = "./measurements.txt";
|
||||||
private static final int BUCKET_SIZE = 1 << 16;
|
private static final int BUCKET_SIZE = 1 << 16;
|
||||||
@ -99,13 +100,13 @@ public class CalculateAverage_abeobk {
|
|||||||
boolean contentEquals(long other_addr, long other_tail) {
|
boolean contentEquals(long other_addr, long other_tail) {
|
||||||
if (tail != other_tail) // compare tail & length at the same time
|
if (tail != other_tail) // compare tail & length at the same time
|
||||||
return false;
|
return false;
|
||||||
long my_addr = addr;
|
// this is faster than comparison if key is short
|
||||||
int nl = (int) (tail >> 59);
|
long xsum = 0;
|
||||||
for (int i = 0; i < nl; i++, my_addr += 8, other_addr += 8) {
|
int n = ((int) (tail >>> 56)) & 0xF8;
|
||||||
if (UNSAFE.getLong(my_addr) != UNSAFE.getLong(other_addr))
|
for (int i = 0; i < n; i += 8) {
|
||||||
return false;
|
xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i));
|
||||||
}
|
}
|
||||||
return true;
|
return xsum == 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,6 +124,7 @@ public class CalculateAverage_abeobk {
|
|||||||
return ptrs;
|
return ptrs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// idea from royvanrijn
|
||||||
static final long getSemiPosCode(final long word) {
|
static final long getSemiPosCode(final long word) {
|
||||||
long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;;
|
long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;;
|
||||||
return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L);
|
return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L);
|
||||||
@ -133,17 +135,164 @@ public class CalculateAverage_abeobk {
|
|||||||
// zero collision on test data
|
// zero collision on test data
|
||||||
static final int xxh32(long hash) {
|
static final int xxh32(long hash) {
|
||||||
final int p1 = 0x85EBCA77; // prime
|
final int p1 = 0x85EBCA77; // prime
|
||||||
final int p2 = 0xC2B2AE3D; // prime
|
final int p2 = 0x165667B1; // prime
|
||||||
int low = (int) hash;
|
int low = (int) hash;
|
||||||
int high = (int) (hash >>> 32);
|
int high = (int) (hash >>> 31);
|
||||||
low ^= low >> 15;
|
int h = low + high;
|
||||||
low *= p1;
|
h ^= h >> 15;
|
||||||
high ^= high >> 13;
|
h *= p1;
|
||||||
high *= p2;
|
h ^= h >> 13;
|
||||||
var h = low ^ high;
|
h *= p2;
|
||||||
|
h ^= h >> 11;
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// great idea from merykitty (Quan Anh Mai)
|
||||||
|
static final int parseNum(long num_word, int dot_pos) {
|
||||||
|
int shift = 28 - dot_pos;
|
||||||
|
long signed = (~num_word << 59) >> 63;
|
||||||
|
long dsmask = ~(signed & 0xFF);
|
||||||
|
long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
|
||||||
|
long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
|
||||||
|
return (int) ((abs_val ^ signed) - signed);
|
||||||
|
}
|
||||||
|
|
||||||
|
// optimize for contest
|
||||||
|
// save as much slow memory access as possible
|
||||||
|
// about 50% of keys are < 8 chars, 25% of keys are between 8-10 chars
|
||||||
|
// keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
|
||||||
|
static final Node[] parse(int thread_id, long start, long end, int[] cls) {
|
||||||
|
long addr = start;
|
||||||
|
var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
|
||||||
|
// parse loop
|
||||||
|
while (addr < end) {
|
||||||
|
long row_addr = addr;
|
||||||
|
long tail = 0;
|
||||||
|
long hash = 0;
|
||||||
|
int val = 0;
|
||||||
|
int bucket = 0;
|
||||||
|
|
||||||
|
long word = UNSAFE.getLong(addr);
|
||||||
|
long semipos_code = getSemiPosCode(word);
|
||||||
|
|
||||||
|
// about 50% chance key < 8 chars
|
||||||
|
if (semipos_code != 0) {
|
||||||
|
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
|
||||||
|
addr += semi_pos;
|
||||||
|
tail = (word & HASH_MASKS[semi_pos]);
|
||||||
|
bucket = xxh32(tail) & BUCKET_MASK;
|
||||||
|
long keylen = (addr - row_addr);
|
||||||
|
tail |= (keylen << 56);
|
||||||
|
long num_word = UNSAFE.getLong(++addr);
|
||||||
|
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
||||||
|
val = parseNum(num_word, dot_pos);
|
||||||
|
addr += (dot_pos >>> 3) + 3;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
var node = map[bucket];
|
||||||
|
if (node == null) {
|
||||||
|
map[bucket] = new Node(row_addr, tail, val);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (node.tail == tail) {
|
||||||
|
node.add(val);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
bucket++;
|
||||||
|
if (SHOW_ANALYSIS)
|
||||||
|
cls[thread_id]++;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
hash ^= word;
|
||||||
|
addr += 8;
|
||||||
|
word = UNSAFE.getLong(addr);
|
||||||
|
semipos_code = getSemiPosCode(word);
|
||||||
|
// first byte semicolon ~13%
|
||||||
|
if (semipos_code == 0x80) {
|
||||||
|
bucket = xxh32(hash) & BUCKET_MASK;
|
||||||
|
tail = 8L << 56;
|
||||||
|
long num_word = word >>> 8;
|
||||||
|
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
||||||
|
val = parseNum(num_word, dot_pos);
|
||||||
|
addr += (dot_pos >>> 3) + 4;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
var node = map[bucket];
|
||||||
|
if (node == null) {
|
||||||
|
map[bucket] = new Node(row_addr, tail, val);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (UNSAFE.getLong(node.addr) == UNSAFE.getLong(row_addr)) {
|
||||||
|
node.add(val);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
bucket++;
|
||||||
|
if (SHOW_ANALYSIS)
|
||||||
|
cls[thread_id]++;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (semipos_code == 0) {
|
||||||
|
hash ^= word;
|
||||||
|
addr += 8;
|
||||||
|
word = UNSAFE.getLong(addr);
|
||||||
|
semipos_code = getSemiPosCode(word);
|
||||||
|
}
|
||||||
|
|
||||||
|
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
|
||||||
|
addr += semi_pos;
|
||||||
|
tail = (word & HASH_MASKS[semi_pos]);
|
||||||
|
hash ^= tail;
|
||||||
|
bucket = xxh32(hash) & BUCKET_MASK;
|
||||||
|
long keylen = (addr - row_addr);
|
||||||
|
tail |= (keylen << 56);
|
||||||
|
|
||||||
|
++addr;
|
||||||
|
long num_word = UNSAFE.getLong(addr);
|
||||||
|
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
||||||
|
val = parseNum(num_word, dot_pos);
|
||||||
|
addr += (dot_pos >>> 3) + 3;
|
||||||
|
|
||||||
|
if (keylen < 16) {
|
||||||
|
while (true) {
|
||||||
|
var node = map[bucket];
|
||||||
|
if (node == null) {
|
||||||
|
map[bucket] = new Node(row_addr, tail, val);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (node.tail == tail && (UNSAFE.getLong(node.addr) == UNSAFE.getLong(row_addr))) {
|
||||||
|
node.add(val);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
bucket++;
|
||||||
|
if (SHOW_ANALYSIS)
|
||||||
|
cls[thread_id]++;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// longer key
|
||||||
|
while (true) {
|
||||||
|
var node = map[bucket];
|
||||||
|
if (node == null) {
|
||||||
|
map[bucket] = new Node(row_addr, tail, val);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (node.contentEquals(row_addr, tail)) {
|
||||||
|
node.add(val);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
bucket++;
|
||||||
|
if (SHOW_ANALYSIS)
|
||||||
|
cls[thread_id]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
public static void main(String[] args) throws InterruptedException, IOException {
|
public static void main(String[] args) throws InterruptedException, IOException {
|
||||||
try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
|
try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
|
||||||
long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
|
long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
|
||||||
@ -158,71 +307,14 @@ public class CalculateAverage_abeobk {
|
|||||||
var threads = new Thread[cpu_cnt];
|
var threads = new Thread[cpu_cnt];
|
||||||
var maps = new Node[cpu_cnt][];
|
var maps = new Node[cpu_cnt][];
|
||||||
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
|
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
|
||||||
int[] cls = new int[cpu_cnt];
|
|
||||||
|
int[] cls = new int[cpu_cnt]; // collision
|
||||||
|
int[] lenhist = new int[64]; // length histogram
|
||||||
|
|
||||||
for (int i = 0; i < cpu_cnt; i++) {
|
for (int i = 0; i < cpu_cnt; i++) {
|
||||||
int thread_id = i;
|
int thread_id = i;
|
||||||
long start = ptrs[i];
|
(threads[thread_id] = new Thread(() -> {
|
||||||
long end = ptrs[i + 1];
|
maps[thread_id] = parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1], cls);
|
||||||
maps[i] = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
|
|
||||||
|
|
||||||
(threads[i] = new Thread(() -> {
|
|
||||||
long addr = start;
|
|
||||||
var map = maps[thread_id];
|
|
||||||
// parse loop
|
|
||||||
while (addr < end) {
|
|
||||||
long hash = 0;
|
|
||||||
long word = 0;
|
|
||||||
long row_addr = addr;
|
|
||||||
int semi_pos = 8;
|
|
||||||
word = UNSAFE.getLong(addr);
|
|
||||||
long semipos_code = getSemiPosCode(word);
|
|
||||||
|
|
||||||
while (semipos_code == 0) {
|
|
||||||
hash ^= word;
|
|
||||||
addr += 8;
|
|
||||||
word = UNSAFE.getLong(addr);
|
|
||||||
semipos_code = getSemiPosCode(word);
|
|
||||||
}
|
|
||||||
|
|
||||||
semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
|
|
||||||
long tail = word & HASH_MASKS[semi_pos];
|
|
||||||
hash ^= tail;
|
|
||||||
addr += semi_pos;
|
|
||||||
|
|
||||||
int hash32 = xxh32(hash);
|
|
||||||
long keylen = (addr - row_addr);
|
|
||||||
tail = tail | (keylen << 56);
|
|
||||||
|
|
||||||
addr++;
|
|
||||||
|
|
||||||
// great idea from merykitty (Quan Anh Mai)
|
|
||||||
long num_word = UNSAFE.getLong(addr);
|
|
||||||
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
|
|
||||||
addr += (dot_pos >>> 3) + 3;
|
|
||||||
int shift = 28 - dot_pos;
|
|
||||||
long signed = (~num_word << 59) >> 63;
|
|
||||||
long dsmask = ~(signed & 0xFF);
|
|
||||||
long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
|
|
||||||
long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
|
|
||||||
int val = (int) ((abs_val ^ signed) - signed);
|
|
||||||
|
|
||||||
int bucket = (hash32 & BUCKET_MASK);
|
|
||||||
while (true) {
|
|
||||||
var node = map[bucket];
|
|
||||||
if (node == null) {
|
|
||||||
map[bucket] = new Node(row_addr, tail, val);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (node.contentEquals(row_addr, tail)) {
|
|
||||||
node.add(val);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
bucket++;
|
|
||||||
if (SHOW_COLLISIONS)
|
|
||||||
cls[thread_id]++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})).start();
|
})).start();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,7 +322,7 @@ public class CalculateAverage_abeobk {
|
|||||||
for (var thread : threads)
|
for (var thread : threads)
|
||||||
thread.join();
|
thread.join();
|
||||||
|
|
||||||
if (SHOW_COLLISIONS) {
|
if (SHOW_ANALYSIS) {
|
||||||
for (int i = 0; i < cpu_cnt; i++) {
|
for (int i = 0; i < cpu_cnt; i++) {
|
||||||
System.out.println("thread-" + i + " collision = " + cls[i]);
|
System.out.println("thread-" + i + " collision = " + cls[i]);
|
||||||
}
|
}
|
||||||
@ -242,13 +334,22 @@ public class CalculateAverage_abeobk {
|
|||||||
for (var node : map) {
|
for (var node : map) {
|
||||||
if (node == null)
|
if (node == null)
|
||||||
continue;
|
continue;
|
||||||
|
if (SHOW_ANALYSIS) {
|
||||||
|
int kl = (int) (node.tail >>> 56) & (lenhist.length - 1);
|
||||||
|
lenhist[kl] += node.count;
|
||||||
|
}
|
||||||
var stat = ms.putIfAbsent(node.key(), node);
|
var stat = ms.putIfAbsent(node.key(), node);
|
||||||
if (stat != null)
|
if (stat != null)
|
||||||
stat.merge(node);
|
stat.merge(node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!SHOW_COLLISIONS)
|
if (SHOW_ANALYSIS) {
|
||||||
|
System.out.println("total=" + Arrays.stream(lenhist).sum());
|
||||||
|
System.out.println("length_histogram = "
|
||||||
|
+ Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray()));
|
||||||
|
}
|
||||||
|
else
|
||||||
System.out.println(ms);
|
System.out.println(ms);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user