/*
 * Copyright 2023 The original authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package dev.morling.onebrc;

import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.IntSupplier;
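
/**
 * One Billion Row Challenge entry: reads "station;temperature" lines from
 * measurements.txt, aggregates the min/mean/max temperature per station in
 * parallel, and prints the results sorted by station name.
 */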
public class CalculateAverage_richardstartin {

    private static final String FILE = "./measurements.txt";

    private record ResultRow(double min, double mean, double max) {

        public String toString() {
            return round(min) + "/" + round(mean) + "/" + round(max);
        }

        private double round(double value) {
            return Math.round(value * 10.0) / 10.0;
        }
    }

    static String bufferToString(ByteBuffer slice) {
        byte[] bytes = new byte[slice.limit()];
        slice.get(0, bytes);
        return new String(bytes, StandardCharsets.UTF_8);
    }
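
    /*
     * Parses a temperature such as "-12.3" as a fixed-point integer: '-'
     * records the sign (and falls through to the '.' break), '.' is skipped,
     * digits are accumulated in base 10, and the result is divided by 10 at
     * the end, relying on the input format's single fractional digit.
     */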
    static double parseTemperature(ByteBuffer slice) {
        // credit: adapted from spullara's submission
        int value = 0;
        int negative = 1;
        int i = 0;
        while (i != slice.limit()) {
            byte b = slice.get(i++);
            switch (b) {
                case '-':
                    negative = -1;
                case '.':
                    break;
                default:
                    value = 10 * value + (b - '0');
            }
        }
        value *= negative;
        return value / 10.0;
    }

    @FunctionalInterface
    interface IndexedStringConsumer {
        void accept(String value, int index);
    }

    /** Maps text to an integer encoding. Adapted from async-profiler. */
    public static class Dictionary {

        private static final int ROW_BITS = 12;
        private static final int ROWS = (1 << ROW_BITS);
        private static final int CELLS = 3;
        private static final int TABLE_CAPACITY = (ROWS * CELLS);

        private final Table table = new Table(nextBaseIndex());

        private static final AtomicIntegerFieldUpdater<Dictionary> BASE_INDEX_UPDATER = AtomicIntegerFieldUpdater.newUpdater(Dictionary.class, "baseIndex");
        volatile int baseIndex;

        private void forEach(Table table, IndexedStringConsumer consumer) {
            for (int i = 0; i < ROWS; i++) {
                Row row = table.rows[i];
                for (int j = 0; j < CELLS; j++) {
                    var slice = row.keys.get(j);
                    if (slice != null) {
                        consumer.accept(bufferToString(slice), table.index(i, j));
                    }
                }
                if (row.next != null) {
                    forEach(row.next, consumer);
                }
            }
        }

        public void forEach(IndexedStringConsumer consumer) {
            forEach(this.table, consumer);
        }
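
        /*
         * Lock-free insertion: each key hashes to a row of CELLS slots; an
         * empty slot is claimed with compareAndSet (and the winner's key is
         * re-read on failure), and when a row is full the hash is rotated and
         * the search moves to a chained overflow table. The returned id is
         * globally unique, so per-thread aggregates can be merged by id
         * without any locking.
         */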
        public int encode(int hash, ByteBuffer slice) {
            Table table = this.table;
            while (true) {
                int rowIndex = Math.abs(hash) % ROWS;
                Row row = table.rows[rowIndex];
                for (int c = 0; c < CELLS; c++) {
                    ByteBuffer storedKey = row.keys.get(c);
                    if (storedKey == null) {
                        if (row.keys.compareAndSet(c, null, slice)) {
                            return table.index(rowIndex, c);
                        }
                        else {
                            storedKey = row.keys.get(c);
                            if (slice.equals(storedKey)) {
                                return table.index(rowIndex, c);
                            }
                        }
                    }
                    else if (slice.equals(storedKey)) {
                        return table.index(rowIndex, c);
                    }
                }
                table = row.getOrCreateNextTable(this::nextBaseIndex);
                hash = Integer.rotateRight(hash, ROW_BITS);
            }
        }

        private int nextBaseIndex() {
            return BASE_INDEX_UPDATER.addAndGet(this, TABLE_CAPACITY);
        }

        private static final class Row {

            private static final AtomicReferenceFieldUpdater<Row, Table> NEXT_TABLE_UPDATER = AtomicReferenceFieldUpdater.newUpdater(Row.class, Table.class, "next");
            private final AtomicReferenceArray<ByteBuffer> keys = new AtomicReferenceArray<>(CELLS);
            volatile Table next;

            public Table getOrCreateNextTable(IntSupplier baseIndexSupplier) {
                Table next = this.next;
                if (next == null) {
                    Table newTable = new Table(baseIndexSupplier.getAsInt());
                    if (NEXT_TABLE_UPDATER.compareAndSet(this, null, newTable)) {
                        next = newTable;
                    }
                    else {
                        next = this.next;
                    }
                }
                return next;
            }
        }
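
        /*
         * Cell ids are encoded as baseIndex + (column << ROW_BITS) + row, so
         * each table hands out ids from its own disjoint TABLE_CAPACITY-sized
         * range and ids never collide across chained tables.
         */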
        private static final class Table {

            final Row[] rows;
            final int baseIndex;

            private Table(int baseIndex) {
                this.baseIndex = baseIndex;
                this.rows = new Row[ROWS];
                Arrays.setAll(rows, i -> new Row());
            }

            int index(int row, int col) {
                return baseIndex + (col << ROW_BITS) + row;
            }
        }
    }
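
    /*
     * Broadcasts a single byte into all eight lanes of a long, producing the
     * SWAR search pattern consumed by firstInstance below.
     */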
    private static long compilePattern(long repeat) {
        return 0x101010101010101L * repeat;
    }

    private static long compilePattern(char delimiter) {
        return compilePattern(delimiter & 0xFFL);
    }

    private static long compilePattern(byte delimiter) {
        return compilePattern(delimiter & 0xFFL);
    }

    private static final long NEW_LINE = compilePattern((byte) '\n');
    private static final long DELIMITER = compilePattern(';');
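
    /*
     * Classic SWAR zero-byte search: XORing the word with the pattern zeroes
     * any matching byte, and ((x & 0x7F..7F) + 0x7F..7F) | x | 0x7F..7F has
     * its high bit set in every non-zero byte, so the complement leaves 0x80
     * exactly at the matches. The trailing-zero count divided by 8 is the
     * offset of the first match in the little-endian word, or 8 (Long.BYTES)
     * when there is no match; e.g. for the bytes "abc;defg" and the ';'
     * pattern the result is 3.
     */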
    private static int firstInstance(long word, long pattern) {
        long input = word ^ pattern;
        long tmp = (input & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL;
        tmp = ~(tmp | input | 0x7F7F7F7F7F7F7F7FL);
        return Long.numberOfTrailingZeros(tmp) >>> 3;
    }

    private static int findLastNewLine(ByteBuffer buffer) {
        return findLastNewLine(buffer, buffer.limit() - 1);
    }

    private static int findLastNewLine(ByteBuffer buffer, int offset) {
        for (int i = offset; i >= 0; i--) {
            if (buffer.get(i) == '\n') {
                return i;
            }
        }
        return 0;
    }
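
    /*
     * Scans the buffer a word at a time using the SWAR pattern, falling back
     * to a byte-at-a-time loop for the final, partial word; returns the
     * buffer's limit when the byte is absent.
     */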
    private static int findIndexOf(ByteBuffer buffer, int offset, long pattern) {
        int i = offset;
        for (; i + Long.BYTES < buffer.limit(); i += Long.BYTES) {
            int index = firstInstance(buffer.getLong(i), pattern);
            if (index != Long.BYTES) {
                return i + index;
            }
        }
        byte b = (byte) (pattern & 0xFF);
        for (; i < buffer.limit(); i++) {
            if (buffer.get(i) == b) {
                return i;
            }
        }
        return buffer.limit();
    }
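
    /*
     * Aggregates live in paged double[] arrays indexed by dictionary id: each
     * id owns four consecutive slots (count, min, max, sum). Pages are copied
     * from a prototype with min/max pre-seeded to +/- infinity, and are only
     * allocated for id ranges a thread actually encounters.
     */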
    static class Page {

        static final int PAGE_SIZE = 1024;
        static final int PAGE_SHIFT = Integer.numberOfTrailingZeros(PAGE_SIZE);
        static final int PAGE_MASK = PAGE_SIZE - 1;
        private static final double[] PAGE_PROTOTYPE = new double[PAGE_SIZE * 4];

        static {
            for (int i = 0; i < PAGE_SIZE * 4; i += 4) {
                PAGE_PROTOTYPE[i + 1] = Double.POSITIVE_INFINITY;
                PAGE_PROTOTYPE[i + 2] = Double.NEGATIVE_INFINITY;
            }
        }

        private static double[] newPage() {
            return Arrays.copyOf(PAGE_PROTOTYPE, PAGE_PROTOTYPE.length);
        }

        static void update(List<double[]> pages, int position, double value) {
            // find the page
            int pageIndex = position >>> PAGE_SHIFT;
            if (pageIndex >= pages.size()) {
                for (int i = pages.size(); i <= pageIndex; i++) {
                    pages.add(null);
                }
            }
            double[] page = pages.get(pageIndex);
            if (page == null) {
                page = newPage();
                pages.set(pageIndex, page);
            }

            // update local aggregates
            int base = (position & PAGE_MASK) * 4;
            page[base]++; // count
            page[base + 1] = Math.min(page[base + 1], value); // min
            page[base + 2] = Math.max(page[base + 2], value); // max
            page[base + 3] += value; // sum
        }

        static ResultRow toResultRow(List<double[]> pages, int position) {
            double[] page = pages.get(position >>> PAGE_SHIFT);
            int base = (position & PAGE_MASK) * 4;
            double count = page[base];
            double min = page[base + 1];
            double max = page[base + 2];
            double sum = page[base + 3];
            return new ResultRow(min, sum / count, max);
        }
    }
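
    /*
     * Fork/join divide-and-conquer over the slices: each leaf task aggregates
     * one slice into its own list of pages, and sibling results are merged
     * pairwise on the way back up, so no synchronization is needed on the
     * aggregates themselves.
     */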
    private static class AggregationTask extends RecursiveTask<List<double[]>> {

        private final Dictionary dictionary;
        private final List<ByteBuffer> slices;
        private final int min;
        private final int max;

        private AggregationTask(Dictionary dictionary, List<ByteBuffer> slices) {
            this(dictionary, slices, 0, slices.size() - 1);
        }

        private AggregationTask(Dictionary dictionary, List<ByteBuffer> slices, int min, int max) {
            this.dictionary = dictionary;
            this.slices = slices;
            this.min = min;
            this.max = max;
        }

        private void computeSlice(ByteBuffer slice, List<double[]> pages) {
            for (int offset = 0; offset < slice.limit();) {
                int nextSeparator = findIndexOf(slice, offset, DELIMITER);
                ByteBuffer key = slice.slice(offset, nextSeparator - offset).order(ByteOrder.LITTLE_ENDIAN);
                // find the global dictionary code to aggregate under;
                // a shared code space makes merging the page lists trivial
                int dictId = dictionary.encode(key.hashCode(), key);

                offset = nextSeparator + 1;
                int newLine = findIndexOf(slice, offset, NEW_LINE);
                // parse the temperature
                double d = parseTemperature(slice.slice(offset, newLine - offset));

                Page.update(pages, dictId, d);

                offset = newLine + 1;
            }
        }

        private static void merge(List<double[]> contribution, List<double[]> aggregate) {
            for (int i = aggregate.size(); i < contribution.size(); i++) {
                aggregate.add(null);
            }
            for (int i = 0; i < contribution.size(); i++) {
                if (aggregate.get(i) == null) {
                    aggregate.set(i, contribution.get(i));
                }
                else if (contribution.get(i) != null) {
                    double[] to = aggregate.get(i);
                    double[] from = contribution.get(i);
                    // todo: won't vectorise - consider separating aggregates into distinct regions
                    // and applying loop fission (if this shows up in the profile)
                    for (int j = 0; j < to.length; j += 4) {
                        to[j] += from[j];
                        to[j + 1] = Math.min(to[j + 1], from[j + 1]);
                        to[j + 2] = Math.max(to[j + 2], from[j + 2]);
                        to[j + 3] += from[j + 3];
                    }
                }
            }
        }

        @Override
        protected List<double[]> compute() {
            if (min == max) {
                var pages = new ArrayList<double[]>();
                var slice = slices.get(min);
                computeSlice(slice, pages);
                return pages;
            }
            else {
                int mid = (min + max) / 2;
                var low = new AggregationTask(dictionary, slices, min, mid);
                var high = new AggregationTask(dictionary, slices, mid + 1, max);
                var fork = high.fork();
                var partial = low.compute();
                merge(fork.join(), partial);
                return partial;
            }
        }
    }
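
    /*
     * Maps the file in as few mmap calls as the 2GiB-per-ByteBuffer limit
     * allows, trims each mapping back to its last complete line, then cuts
     * the mappings into ~250MiB newline-aligned slices for the fork/join
     * pool. The dictionary is shared by all tasks, so each station's id
     * indexes the same aggregate slots in the merged page list.
     */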
    public static void main(String[] args) throws IOException {
        int maxChunkSize = 250 << 20; // 250MiB
        try (var raf = new RandomAccessFile(FILE, "r");
                var channel = raf.getChannel()) {
            long size = channel.size();
            // make as few mmap calls as possible subject to the 2GiB limit per buffer
            List<ByteBuffer> rawBuffers = new ArrayList<>();
            for (long offset = 0; offset < size - 1;) {
                long end = Math.min(Integer.MAX_VALUE, size - offset);
                ByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, offset, end)
                        .order(ByteOrder.LITTLE_ENDIAN);
                boolean lastSlice = end != Integer.MAX_VALUE;
                int limit = lastSlice
                        ? (int) end
                        : findLastNewLine(buffer);
                rawBuffers.add(buffer.limit(limit));
                offset += limit;
            }

            // now slice them up for parallel processing
            var slices = new ArrayList<ByteBuffer>();
            for (ByteBuffer rawBuffer : rawBuffers) {
                for (int offset = 0; offset < rawBuffer.limit();) {
                    int chunkSize = Math.min(rawBuffer.limit() - offset, maxChunkSize);
                    int target = offset + chunkSize;
                    int limit = target >= rawBuffer.limit()
                            ? rawBuffer.limit()
                            : findLastNewLine(rawBuffer, target);
                    int adjustment = rawBuffer.get(offset) == '\n' ? 1 : 0;
                    var slice = rawBuffer.slice(offset + adjustment, limit - offset - adjustment).order(ByteOrder.LITTLE_ENDIAN);
                    slices.add(slice);
                    offset = limit;
                }
            }

            var fjp = new ForkJoinPool(Runtime.getRuntime().availableProcessors());
            Dictionary dictionary = new Dictionary();
            var aggregates = fjp.submit(new AggregationTask(dictionary, slices)).join();
            var map = new TreeMap<String, ResultRow>();
            dictionary.forEach((key, index) -> map.put(key, Page.toResultRow(aggregates, index)));
            System.out.println(map);
        }
    }
}