CalculateAverage_stephenvonworley submission (#677)

* first release

* change constants to names

---------

Co-authored-by: Stephen Von Worley <von@von.io>
This commit is contained in:
Stephen Von Worley 2024-01-31 12:49:19 -08:00 committed by GitHub
parent a8823a1f93
commit a533019ad4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 580 additions and 0 deletions

View File

@ -0,0 +1,25 @@
#!/bin/sh
#
# Copyright 2023 The original authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if [ -f target/CalculateAverage_stephenvonworley_image ]; then
target/CalculateAverage_stephenvonworley_image
else
JAVA_OPTS="--enable-preview"
echo "Chosing to run the app in JVM mode as no native image was found, use prepare_stephenvonworley.sh to generate." 1>&2
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_stephenvonworley
fi

25
prepare_stephenvonworley.sh Executable file
View File

@ -0,0 +1,25 @@
#!/bin/bash
#
# Copyright 2023 The original authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.2-graal 1>&2
# ./mvnw clean verify removes target/ and will re-trigger native image creation.
if [ ! -f target/CalculateAverage_stephenvonworley_image ]; then
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -H:TuneInlinerExploration=1 -march=native --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_stephenvonworley"
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_stephenvonworley_image dev.morling.onebrc.CalculateAverage_stephenvonworley
fi

View File

@ -0,0 +1,530 @@
/*
* Copyright 2023 The original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.morling.onebrc;
import java.io.*;
import java.lang.foreign.*;
import java.lang.reflect.Field;
import java.nio.*;
import java.nio.channels.*;
import java.nio.file.*;
import java.nio.charset.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.stream.*;
import sun.misc.Unsafe;
/*
* Stephen Von Worley's (von@von.io) entry to Gunnar Morling's "One Billion Row Challenge":
* https://www.morling.dev/blog/one-billion-row-challenge/
*
* To compute the desired result, this program:
* 1. Memory maps the input file.
* 2. Partitions the file into a queue of Chunks, which delimit sections of the file.
* 3. Spawns one thread per processor. Each thread:
* a. Allocates a Table, which will accumulate names and tallies (min/max/total/count).
* b. Get a Chunk from the queue.
* c. Processes the Chunk using a parser that reads the Chunk simultaneously at three
* different, evenly-spaced locations, using heavily-optimized scalar code.
* d. Repeats steps b and c until there are no more Chunks.
* 4. Aggregates the resulting Tables into a treemap of names to Tallies.
* 5. Outputs the names and Tallies in ascending name order.
*
* Runs fastest as a natively-compiled, standalone binary, as might be produced by Graal's
* `native-image` utility. Tested with Oracle Graal 21.0.2.
*
* Incorporates code authored by a number of submitters, including Thomas Wue, Quan Anh
* Mai, and others.
*
* Thanks y'all, and Happy Rowing!
* Steve
* von@von.io
* www.von.io
*/
public class CalculateAverage_stephenvonworley {
private static final int NAME_LIMIT = 10000;
private static final long CHUNK_SIZE = 5000000;
private static final long CHUNK_PAD = 200;
private static final long CHUNK_PARSE3_LIMIT = 1000;
private static final long GOLDEN_LONG = 0x9e3779b97f4a7c15L;
private static final long TALLY_BITS = 7;
private static final long TALLY_SIZE = 1L << TALLY_BITS;
private static final long HASH_BITS = 16;
private static final long HASH_MASK = ((1L << HASH_BITS) - 1) << TALLY_BITS;
private static final long TABLE_SIZE = 1L << (HASH_BITS + TALLY_BITS);
private static final long OFFSET_MIN = 0;
private static final long OFFSET_MAX = 2;
private static final long OFFSET_COUNT = 4;
private static final long OFFSET_TOTAL = 8;
private static final long OFFSET_LEN = 16;
private static final long OFFSET_NAME = 17;
private static final Unsafe unsafe;
static {
try {
Field f = Unsafe.class.getDeclaredField("theUnsafe");
f.setAccessible(true);
unsafe = (Unsafe) f.get(null);
}
catch (Exception e) {
throw new RuntimeException("Exception initializing unsafe", e);
}
}
public static void main(String[] args) throws IOException, InterruptedException {
if (!List.of(args).contains("--worker")) {
spawnWorker();
return;
}
MemorySegment in = map("./measurements.txt");
Queue<Chunk> chunks = partition(in);
List<Table> tables = process(chunks, processorCount());
Map<String, Tally> nameToTally = aggregate(tables);
System.out.println(nameToTally);
System.out.close();
}
// credit: "Spawn worker" code by Thomas Wue
private static void spawnWorker() throws IOException {
ProcessHandle.Info info = ProcessHandle.current().info();
ArrayList<String> workerCommand = new ArrayList<>();
info.command().ifPresent(workerCommand::add);
info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args)));
workerCommand.add("--worker");
new ProcessBuilder().command(workerCommand).inheritIO().redirectOutput(ProcessBuilder.Redirect.PIPE)
.start().getInputStream().transferTo(System.out);
}
private static int processorCount() {
return Runtime.getRuntime().availableProcessors();
}
private static MemorySegment map(String path) throws IOException {
FileChannel file = FileChannel.open(Path.of(path), StandardOpenOption.READ);
return file.map(FileChannel.MapMode.READ_ONLY, 0, file.size(), Arena.global());
}
private static MemorySegment allocate(long len) {
return Arena.global().allocate(len, 4096);
}
private static Queue<Chunk> partition(MemorySegment in) throws IOException {
Queue<Chunk> chunks = new ConcurrentLinkedDeque<>();
long address = in.address();
long len = in.byteSize();
long start = address;
while (start < address + len) {
long end = start + CHUNK_SIZE;
if (end >= address + len) {
end = address + len;
}
else {
end = afterNewline(end);
}
Chunk chunk;
if (end + CHUNK_PAD < address + len) {
chunk = new Chunk(start, end);
}
else {
MemorySegment padded = allocate(end - start + CHUNK_PAD);
MemorySegment.copy(in, start - address, padded, 0, end - start);
chunk = new Chunk(padded.address(), padded.address() + (end - start));
}
chunks.offer(chunk);
start = end;
}
return chunks;
}
private static List<Table> process(Queue<Chunk> chunks, int threadCount) throws InterruptedException {
List<Table> tables = Collections.synchronizedList(new ArrayList<>(threadCount));
List<Thread> threads = new ArrayList<>(threadCount);
for (int i = 0; i < threadCount; i++) {
Thread thread = new Thread(() -> {
Table t = new Table();
tables.add(t);
Chunk chunk;
while ((chunk = chunks.poll()) != null) {
parse3(chunk.start(), chunk.end(), t);
}
});
threads.add(thread);
thread.start();
}
for (Thread thread : threads) {
thread.join();
}
return tables;
}
private static Map<String, Tally> aggregate(List<Table> tables) {
Map<String, Tally> nameToTally = new TreeMap<>();
tables.forEach(table -> aggregate(nameToTally, table));
return nameToTally;
}
private static void aggregate(Map<String, Tally> nameToTally, Table table) {
table.process((name, min, max, total, count) -> nameToTally.computeIfAbsent(name, _ -> new Tally()).add(min, max, total, count));
}
private static void parse3(long start, long end, Table table) {
if (end - start < CHUNK_PARSE3_LIMIT) {
parse1(start, end, table);
return;
}
final long tallies = table.tallies;
long part = (end - start) / 3;
long startA = start;
long startB = afterNewline(start + part);
long startC = afterNewline(start + 2 * part);
long endA = startB;
long endB = startC;
long endC = end;
while (true) {
long N = min(
remaining(startA, endA),
remaining(startB, endB),
remaining(startC, endC));
if (N <= 1) {
break;
}
while (N > 0) {
long semicolonA = semicolon(startA);
long semicolonB = semicolon(startB);
long semicolonC = semicolon(startC);
long tallyA = locate(startA, semicolonA, tallies, table);
long tallyB = locate(startB, semicolonB, tallies, table);
long tallyC = locate(startC, semicolonC, tallies, table);
long numberA = number(semicolonA);
tally(tallyA, numberA);
long numberB = number(semicolonB);
tally(tallyB, numberB);
long numberC = number(semicolonC);
tally(tallyC, numberC);
startA = next(semicolonA);
startB = next(semicolonB);
startC = next(semicolonC);
N--;
}
}
parse1(startA, endA, table);
parse1(startB, endB, table);
parse1(startC, endC, table);
}
private static void parse1(long start, long end, Table table) {
final long tallies = table.tallies;
while (start < end) {
long semicolon = semicolon(start);
long tally = locate(start, semicolon, tallies, table);
long number = number(semicolon);
tally(tally, number);
start = next(semicolon);
}
}
private static long remaining(long start, long end) {
return (end - start) >> 7;
}
// credit: Adapted from code by Thomas Wue
private static long semicolon(long start) {
start++;
long word = getLong(start);
long input = word ^ 0x3B3B3B3B3B3B3B3BL;
long tmp = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L;
if (tmp != 0) {
return start + (Long.numberOfTrailingZeros(tmp) >>> 3);
}
while (true) {
start += 8;
long word2 = getLong(start);
long input2 = word2 ^ 0x3B3B3B3B3B3B3B3BL;
long tmp2 = (input2 - 0x0101010101010101L) & ~input2 & 0x8080808080808080L;
if (tmp2 != 0) {
return start + (Long.numberOfTrailingZeros(tmp2) >>> 3);
}
}
}
private static long trim(long value, long remove) {
long shift = remove << 3;
return ((value << shift) >>> shift);
}
// https://softwareengineering.stackexchange.com/questions/402542/where-do-magic-hashing-constants-like-0x9e3779b9-and-0x9e3779b1-come-from
private static long locate(long start, long semicolon, long tallies, Table table) {
long len = semicolon - start;
long word = getLong(start);
if (len <= 8) {
word = trim(word, 8 - len);
long hash = word * GOLDEN_LONG;
long offset = (hash >>> (64 - HASH_BITS)) << TALLY_BITS;
while (true) {
long tally = tallies + offset;
long tlen = getByte(tally + OFFSET_LEN);
long tword = getLong(tally + OFFSET_NAME);
if (len == tlen && word == tword) {
return tally;
}
if (tword == 0) {
init(tally, start, len, table);
return tally;
}
offset = (offset + TALLY_SIZE) & HASH_MASK;
}
}
else {
long word2 = getLong(semicolon - 8);
long hash = (word + word2) * GOLDEN_LONG;
long offset = (hash >>> (64 - HASH_BITS)) << TALLY_BITS;
while (true) {
long tally = tallies + offset;
long tword = getLong(tally + OFFSET_NAME);
if (len <= 16) {
long tlen = getByte(tally + OFFSET_LEN);
long tword2 = getLong(tally + OFFSET_NAME + len - 8);
if (len == tlen && word == tword && word2 == tword2) {
return tally;
}
}
else {
if (match(tally, start, len)) {
return tally;
}
}
if (tword == 0) {
init(tally, start, len, table);
return tally;
}
offset = (offset + TALLY_SIZE) & HASH_MASK;
}
}
}
private static void init(long tally, long start, long len, Table t) {
setShort(tally + OFFSET_MIN, Short.MAX_VALUE);
setShort(tally + OFFSET_MAX, Short.MIN_VALUE);
setByte(tally + OFFSET_LEN, (byte) len);
copyMemory(start, tally + OFFSET_NAME, len);
t.addresses[t.count++] = tally;
}
private static boolean match(long tally, long name, long len) {
if (getByte(tally + OFFSET_LEN) != len) {
return false;
}
long a = name;
long b = tally + OFFSET_NAME;
while (len > 7) {
if (getLong(a) != getLong(b)) {
return false;
}
a += 8;
b += 8;
len -= 8;
}
if (len > 0) {
return (trim(getLong(a), 8 - len) == getLong(b));
}
return true;
}
// credit: Wonderfully-fast number parsing implementation by Quan Anh Mai
private static long number(long semicolon) {
long numberWord = getLong(semicolon + 1);
int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000);
int shift = 28 - decimalSepPos;
// signed is -1 if negative, 0 otherwise
long signed = (~numberWord << 59) >> 63;
long designMask = ~(signed & 0xFF);
// Align the number to a specific position and transform the ascii to digit value
long digits = ((numberWord & designMask) << shift) & 0x0F000F0F00L;
// Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit)
// 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
// 0x000000UU00TTHH00 + 0x00UU00TTHH000000 * 10 + 0xUU00TTHH00000000 * 100
long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
return (absValue ^ signed) - signed;
}
private static void tally(long tally, long number) {
short min = getShort(tally + OFFSET_MIN);
short max = getShort(tally + OFFSET_MAX);
int count = getInt(tally + OFFSET_COUNT);
long total = getLong(tally + OFFSET_TOTAL);
if (number < min) {
setShort(tally + OFFSET_MIN, (short) number);
}
if (number > max) {
setShort(tally + OFFSET_MAX, (short) number);
}
setInt(tally + OFFSET_COUNT, count + 1);
setLong(tally + OFFSET_TOTAL, total + number);
}
private static long next(long semicolon) {
long word = getLong(semicolon);
semicolon += 7;
semicolon -= (~word >>> (24 + 4)) & 1;
semicolon -= (~word >>> (16 + 4 - 1)) & 2;
return semicolon;
}
private static long afterNewline(long start) {
while (getByte(start) != '\n')
start++;
return start + 1;
}
private static long min(long a, long b, long c) {
return Math.min(a, Math.min(b, c));
}
private static byte getByte(long addr) {
return unsafe.getByte(addr);
}
private static short getShort(long addr) {
return unsafe.getShort(addr);
}
private static int getInt(long addr) {
return unsafe.getInt(addr);
}
private static long getLong(long addr) {
return unsafe.getLong(addr);
}
private static void setByte(long addr, byte value) {
unsafe.putByte(addr, value);
}
private static void setShort(long addr, short value) {
unsafe.putShort(addr, value);
}
private static void setInt(long addr, int value) {
unsafe.putInt(addr, value);
}
private static void setLong(long addr, long value) {
unsafe.putLong(addr, value);
}
private static void copyMemory(long srcAddr, long dstAddr, long count) {
unsafe.copyMemory(srcAddr, dstAddr, count);
}
private static record Chunk(long start, long end) {
}
private static class Table {
public final long tallies;
public final long[] addresses;
public int count;
public Table() {
tallies = allocate(TABLE_SIZE).address();
addresses = new long[NAME_LIMIT];
count = 0;
}
public void process(Consumer consumer) {
for (int i = 0; i < count; i++) {
long address = addresses[i];
int len = getByte(address + OFFSET_LEN);
byte[] bytes = new byte[len];
for (int j = 0; j < len; j++) {
bytes[j] = getByte(address + OFFSET_NAME + j);
}
String name = new String(bytes, StandardCharsets.UTF_8);
long min = getShort(address + OFFSET_MIN);
long max = getShort(address + OFFSET_MAX);
long total = getLong(address + OFFSET_TOTAL);
long count = getInt(address + OFFSET_COUNT);
consumer.consume(name, min, max, total, count);
}
}
}
private static interface Consumer {
public void consume(String name, long min, long max, long total, long count);
}
private static class Tally {
private long min;
private long max;
private long total;
private long count;
public Tally() {
this.min = Short.MAX_VALUE;
this.max = Short.MIN_VALUE;
this.total = 0;
this.count = 0;
}
public void add(long addMin, long addMax, long addTotal, long addCount) {
min = Math.min(min, addMin);
max = Math.max(max, addMax);
total += addTotal;
count += addCount;
}
public long getMin() {
return min;
}
public long getMax() {
return max;
}
public long getTotal() {
return total;
}
public long getCount() {
return count;
}
public String toString() {
return String.format("%.1f/%.1f/%.1f",
getMin() / 10.0,
getTotal() / (10.0 * getCount()),
getMax() / 10.0);
}
}
}