// 1brc/src/main/java/dev/morling/onebrc/CalculateAverage_ebarlas.java
//
// 356 lines
// 14 KiB
// Java

/*
* Copyright 2023 The original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.morling.onebrc;
import sun.misc.Unsafe;
import java.io.IOException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.List;
import java.util.TreeMap;
/**
 * Aggregates per-key minimum, mean and maximum from {@code measurements.txt} and prints them
 * as a key-sorted map.
 *
 * <p>Each input line is parsed as {@code key;value\n}, where the value is a decimal with one or two
 * integer digits, exactly one fractional digit and an optional leading minus sign. Values are
 * handled as integers in tenths of a unit.
 *
 * <p>The file is split into equally sized byte ranges, each memory-mapped and parsed by its own
 * thread into a private open-addressing hash table. Because a byte range can start and end in the
 * middle of a line, every partition records the bytes before its first newline (header) and the
 * bytes of its trailing incomplete line (footer); these are stitched together and parsed after all
 * threads have finished. Keys are stored off-heap as arrays of 4-byte ints.
 */
public class CalculateAverage_ebarlas {

    /** Size in bytes of the native scratch area that holds the key of the line being parsed. */
    private static final int MAX_KEY_SIZE = 100;

    /** Multiplier of the polynomial hash computed over the key bytes. */
    private static final int HASH_FACTOR = 433;

    /** Bit mask applied to hashes (2^14 - 1); every table has {@code HASH_TBL_SIZE + 1} slots. */
    private static final int HASH_TBL_SIZE = 16_383; // range of allowed hash values, inclusive

    private static final Unsafe UNSAFE = makeUnsafe();

    /**
     * Obtains the {@link Unsafe} singleton through reflection on its {@code theUnsafe} field.
     *
     * @return the Unsafe instance
     * @throws RuntimeException if the field is missing or cannot be read
     */
    private static Unsafe makeUnsafe() {
        try {
            var f = Unsafe.class.getDeclaredField("theUnsafe");
            f.setAccessible(true);
            return (Unsafe) f.get(null);
        }
        catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Entry point: partitions the file, parses all partitions concurrently, folds the partial
     * results together and prints them to standard output.
     *
     * @param args ignored
     * @throws IOException if the file cannot be opened or its size cannot be read
     * @throws InterruptedException if the calling thread is interrupted while waiting for a worker
     * @throws IllegalStateException if any worker failed; the worker's exception is the cause
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        var path = Paths.get("measurements.txt");
        var numPartitions = Math.max(8, Runtime.getRuntime().availableProcessors());
        var partitions = new Partition[numPartitions];
        var failures = new Throwable[numPartitions]; // slot i is written only by worker i, read after join
        // The channel is only needed to create the mappings; all workers are joined before it is closed.
        try (var channel = FileChannel.open(path, StandardOpenOption.READ)) {
            var fileSize = channel.size();
            var partitionSize = fileSize / numPartitions;
            var threads = new Thread[numPartitions];
            for (int i = 0; i < numPartitions; i++) {
                var pIdx = i;
                var pStart = pIdx * partitionSize;
                var pEnd = pIdx == numPartitions - 1
                        ? fileSize // last partition might be slightly larger
                        : pStart + partitionSize;
                var pSize = pEnd - pStart;
                Runnable r = () -> {
                    try {
                        var buffer = channel.map(FileChannel.MapMode.READ_ONLY, pStart, pSize).order(ByteOrder.LITTLE_ENDIAN);
                        partitions[pIdx] = processBuffer(buffer, pIdx == 0);
                    }
                    catch (IOException | RuntimeException e) {
                        failures[pIdx] = e; // reported by the main thread once every worker is done
                    }
                };
                threads[i] = new Thread(r);
                threads[i].start();
            }
            for (var thread : threads) {
                thread.join();
            }
        }
        for (int i = 0; i < numPartitions; i++) {
            if (failures[i] != null) {
                throw new IllegalStateException("Failed to process partition " + i, failures[i]);
            }
        }
        var partitionList = List.of(partitions);
        foldFootersAndHeaders(partitionList);
        printResults(foldStats(partitionList));
    }

    /**
     * Prints all non-empty table slots as a {@link TreeMap} of key to {@code min/mean/max}.
     *
     * @param stats hash table holding the final statistics; empty slots are {@code null}
     */
    private static void printResults(Stats[] stats) { // adheres to Gunnar's reference code
        var result = new TreeMap<String, String>();
        for (var st : stats) {
            if (st != null) {
                var key = new String(convert(st.keyAddr, st.keyLen, st.lastBytes), StandardCharsets.UTF_8);
                result.put(key, format(st));
            }
        }
        System.out.println(result);
    }

    /**
     * Unpacks an off-heap key into its original byte sequence.
     *
     * @param keyAddr address of the key's int array
     * @param keyLen number of ints in the array
     * @param keyLastBytes number of key bytes held by the last int (1 to 4)
     * @return the key bytes in their original order
     */
    private static byte[] convert(long keyAddr, int keyLen, int keyLastBytes) {
        var len = keyLastBytes == 4
                ? keyLen * 4 // fully packed
                : (keyLen - 1) * 4 + keyLastBytes; // last int partially packed
        var bytes = new byte[len];
        var idx = 0;
        for (long i = 0; i < keyLen; i++) {
            var offset = i << 2;
            var n = UNSAFE.getInt(keyAddr + offset);
            var bound = i == keyLen - 1 ? keyLastBytes : 4;
            for (int j = 0; j < bound; j++) { // first key byte sits in the lowest 8 bits
                bytes[idx++] = (byte) (n & 0xFF);
                n >>>= 8;
            }
        }
        return bytes;
    }

    /**
     * Renders one entry as {@code min/mean/max}, converting tenths back to units.
     *
     * @param st statistics to render
     * @return the formatted text
     */
    private static String format(Stats st) { // adheres to expected output format
        return round(st.min / 10.0) + "/" + round((st.sum / 10.0) / st.count) + "/" + round(st.max / 10.0);
    }

    /**
     * Rounds to one fractional digit using {@link Math#round(double)}.
     *
     * @param value value to round
     * @return the rounded value
     */
    private static double round(double value) { // Gunnar's round function
        return Math.round(value * 10.0) / 10.0;
    }

    /**
     * Merges the tables of all partitions into the table of the first partition.
     *
     * @param partitions all partitions, in file order
     * @return the first partition's table, now holding the combined statistics
     */
    private static Stats[] foldStats(List<Partition> partitions) { // fold stats from all partitions into first partition
        var target = partitions.getFirst().stats;
        for (int i = 1; i < partitions.size(); i++) {
            var current = partitions.get(i).stats;
            for (int j = 0; j < current.length; j++) {
                if (current[j] != null) {
                    // finds the matching entry in the target table, inserting a fresh one if the key is new there
                    var t = findInTable(target, current[j].hash, current[j].keyAddr, current[j].keyLen, current[j].lastBytes);
                    t.min = Math.min(t.min, current[j].min);
                    t.max = Math.max(t.max, current[j].max);
                    t.sum += current[j].sum;
                    t.count += current[j].count;
                }
            }
        }
        return target;
    }

    /**
     * Reassembles the lines that straddle partition boundaries and adds them to the statistics.
     *
     * <p>For each adjacent pair, the previous footer and the next header are concatenated. If the
     * result ends with a newline it is parsed into the previous partition's table; otherwise it is
     * stored as the next partition's footer so that it is extended by the following header.
     *
     * @param partitions all partitions, in file order
     */
    private static void foldFootersAndHeaders(List<Partition> partitions) { // fold footers and headers into prev partition
        for (int i = 1; i < partitions.size(); i++) {
            var pNext = partitions.get(i);
            var pPrev = partitions.get(i - 1);
            var merged = mergeFooterAndHeader(pPrev.footer, pNext.header);
            if (merged != null && merged.length != 0) {
                if (merged[merged.length - 1] == '\n') { // fold into prev partition
                    doProcessBuffer(ByteBuffer.wrap(merged).order(ByteOrder.LITTLE_ENDIAN), true, pPrev.stats);
                }
                else { // no newline appeared in partition, carry forward
                    pNext.footer = merged;
                }
            }
        }
    }

    /**
     * Concatenates a footer and a header.
     *
     * @param footer trailing bytes of the earlier partition, may be {@code null}
     * @param header leading bytes of the later partition, may be {@code null}
     * @return footer followed by header; the other argument as-is when one of them is {@code null}
     */
    private static byte[] mergeFooterAndHeader(byte[] footer, byte[] header) {
        if (footer == null) {
            return header;
        }
        if (header == null) {
            return footer;
        }
        var merged = new byte[footer.length + header.length];
        System.arraycopy(footer, 0, merged, 0, footer.length);
        System.arraycopy(header, 0, merged, footer.length, header.length);
        return merged;
    }

    /**
     * Parses a partition into a newly allocated table.
     *
     * @param buffer little-endian buffer positioned at the start of the partition
     * @param first whether the buffer starts at a line start, so no header has to be split off
     * @return header, footer and statistics of the partition
     */
    private static Partition processBuffer(ByteBuffer buffer, boolean first) {
        return doProcessBuffer(buffer, first, new Stats[HASH_TBL_SIZE + 1]);
    }

    /**
     * Parses a buffer into the given table.
     *
     * @param buffer little-endian buffer positioned at the start of the data
     * @param first whether the buffer starts at a line start, so no header has to be split off
     * @param stats table that receives the statistics
     * @return header ({@code null} if {@code first}), footer ({@code null} if every byte after the
     *         header belonged to a fully parsed line) and the table that was passed in
     */
    private static Partition doProcessBuffer(ByteBuffer buffer, boolean first, Stats[] stats) {
        var header = first ? null : readHeader(buffer);
        var keyStart = reallyDoProcessBuffer(buffer, stats);
        var footer = keyStart < buffer.limit() ? readFooter(buffer, keyStart) : null;
        return new Partition(header, footer, stats);
    }

    /**
     * Parses lines from the buffer's current position until the buffer runs out.
     *
     * <p>Bytes are read optimistically, four at a time, without checking the remaining length; the
     * resulting {@link BufferUnderflowException} is what ends the loop. The statistics are only
     * updated after a line has been read completely, so an aborted line leaves no trace in the table.
     *
     * <p>The current key is staged in a native scratch area of {@link #MAX_KEY_SIZE} bytes. Writes to
     * it are not bounds-checked, so keys are expected to fit within that size.
     *
     * @param buffer little-endian buffer positioned at a line start
     * @param stats table that receives the statistics
     * @return buffer position at which the last, unfinished line started; equals the buffer limit
     *         when all bytes were consumed
     */
    private static int reallyDoProcessBuffer(ByteBuffer buffer, Stats[] stats) {
        long keyBaseAddr = UNSAFE.allocateMemory(MAX_KEY_SIZE);
        int keyStart = 0; // start of key in buffer used for footer calc
        try { // abort with exception to allow optimistic line processing
            while (true) { // one line per iteration
                keyStart = buffer.position(); // preserve line start
                int keyHash = 0; // key hash code
                long keyAddr = keyBaseAddr; // address for next int
                int keyArrLen = 0; // number of key 4-byte ints
                int keyLastBytes; // occupancy in last byte (1, 2, 3, or 4)
                int val; // temperature value
                while (true) {
                    int n = buffer.getInt();
                    byte b0 = (byte) (n & 0xFF);
                    byte b1 = (byte) ((n >> 8) & 0xFF);
                    byte b2 = (byte) ((n >> 16) & 0xFF);
                    byte b3 = (byte) ((n >> 24) & 0xFF);
                    if (b0 == ';') { // ...;1.1
                        val = getVal(buffer, b1, b2, b3, buffer.get());
                        keyLastBytes = 4;
                        break;
                    }
                    else if (b1 == ';') { // ...a;1.1
                        val = getVal(buffer, b2, b3, buffer.get(), buffer.get());
                        UNSAFE.putInt(keyAddr, b0); // b0 is sign-extended here; convert() reads only the low byte
                        keyLastBytes = 1;
                        keyArrLen++;
                        keyHash = HASH_FACTOR * keyHash + b0;
                        break;
                    }
                    else if (b2 == ';') { // ...ab;1.1
                        val = getVal(buffer, b3, buffer.get(), buffer.get(), buffer.get());
                        UNSAFE.putInt(keyAddr, n & 0x0000FFFF);
                        keyLastBytes = 2;
                        keyArrLen++;
                        keyHash = HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1;
                        break;
                    }
                    else if (b3 == ';') { // ...abc;1.1
                        UNSAFE.putInt(keyAddr, n & 0x00FFFFFF);
                        keyLastBytes = 3;
                        keyArrLen++;
                        keyHash = HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1) + b2;
                        n = buffer.getInt();
                        b0 = (byte) (n & 0xFF);
                        b1 = (byte) ((n >> 8) & 0xFF);
                        b2 = (byte) ((n >> 16) & 0xFF);
                        b3 = (byte) ((n >> 24) & 0xFF);
                        val = getVal(buffer, b0, b1, b2, b3);
                        break;
                    }
                    else { // four more key bytes, no separator yet
                        UNSAFE.putInt(keyAddr, n);
                        keyArrLen++;
                        keyAddr += 4;
                        keyHash = HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1) + b2) + b3;
                    }
                }
                var idx = keyHash & HASH_TBL_SIZE;
                var st = stats[idx];
                if (st == null) { // nothing in table, eagerly claim spot
                    st = stats[idx] = newStats(keyBaseAddr, keyArrLen, keyLastBytes, keyHash);
                }
                else if (!equals(st.keyAddr, st.keyLen, keyBaseAddr, keyArrLen)) {
                    st = findInTable(stats, keyHash, keyBaseAddr, keyArrLen, keyLastBytes);
                }
                st.min = Math.min(st.min, val);
                st.max = Math.max(st.max, val);
                st.sum += val;
                st.count++;
            }
        }
        catch (BufferUnderflowException ignored) {
            // expected: this is how the end of the buffer is detected
        }
        finally {
            // table entries own private copies of their keys (see newStats), so the scratch area can go
            UNSAFE.freeMemory(keyBaseAddr);
        }
        return keyStart;
    }

    /**
     * Compares two off-heap keys int by int.
     *
     * @param key1 address of the first key
     * @param len1 number of ints in the first key
     * @param key2 address of the second key
     * @param len2 number of ints in the second key
     * @return {@code true} if both keys have the same number of ints and all ints match
     */
    private static boolean equals(long key1, int len1, long key2, int len2) {
        if (len1 != len2) {
            return false;
        }
        for (long i = 0; i < len1; i++) {
            var offset = i << 2;
            if (UNSAFE.getInt(key1 + offset) != UNSAFE.getInt(key2 + offset)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Decodes the value that follows the separator and consumes the rest of the line, including the
     * newline.
     *
     * @param buffer buffer positioned right after {@code b3}
     * @param b0 first byte after the separator
     * @param b1 second byte after the separator
     * @param b2 third byte after the separator
     * @param b3 fourth byte after the separator
     * @return the value in tenths
     */
    private static int getVal(ByteBuffer buffer, byte b0, byte b1, byte b2, byte b3) {
        if (b0 == '-') {
            if (b2 != '.') { // 6 bytes: -dd.dn
                var b = buffer.get();
                buffer.get(); // newline
                return -(((b1 - '0') * 10 + (b2 - '0')) * 10 + (b - '0'));
            }
            else { // 5 bytes: -d.dn
                buffer.get(); // newline
                return -((b1 - '0') * 10 + (b3 - '0'));
            }
        }
        else {
            if (b1 != '.') { // 5 bytes: dd.dn
                buffer.get(); // newline
                return ((b0 - '0') * 10 + (b1 - '0')) * 10 + (b3 - '0');
            }
            else { // 4 bytes: d.dn
                return (b0 - '0') * 10 + (b2 - '0'); // b3 was the newline, nothing left to consume
            }
        }
    }

    /**
     * Looks up a key in a table, inserting a new entry if the key is absent.
     *
     * <p>Starts at the slot selected by the hash and moves to the next slot, wrapping around at the
     * end of the table, until it reaches the key or an empty slot.
     *
     * @param stats table to search
     * @param hash hash of the key
     * @param keyAddr address of the key's int array
     * @param keyLen number of ints in the key
     * @param keyLastBytes number of key bytes held by the last int
     * @return the existing or newly inserted entry
     */
    private static Stats findInTable(Stats[] stats, int hash, long keyAddr, int keyLen, int keyLastBytes) { // open-addressing scan
        var idx = hash & HASH_TBL_SIZE;
        var st = stats[idx];
        while (st != null && !equals(st.keyAddr, st.keyLen, keyAddr, keyLen)) {
            idx = (idx + 1) % (HASH_TBL_SIZE + 1);
            st = stats[idx];
        }
        if (st != null) {
            return st;
        }
        return stats[idx] = newStats(keyAddr, keyLen, keyLastBytes, hash);
    }

    /**
     * Creates an entry that owns a private native copy of the key.
     *
     * @param keyAddr address of the key to copy
     * @param keyLen number of ints in the key
     * @param keyLastBytes number of key bytes held by the last int
     * @param hash hash of the key
     * @return a new entry with no values recorded yet
     */
    private static Stats newStats(long keyAddr, int keyLen, int keyLastBytes, int hash) {
        var bytes = keyLen << 2;
        long k = UNSAFE.allocateMemory(bytes);
        UNSAFE.copyMemory(keyAddr, k, bytes);
        return new Stats(k, keyLen, keyLastBytes, hash);
    }

    /**
     * Copies the bytes from the given line start up to the buffer limit.
     *
     * @param buffer buffer to copy from; its position is not changed
     * @param lineStart index of the first byte to copy
     * @return the copied bytes
     */
    private static byte[] readFooter(ByteBuffer buffer, int lineStart) { // read from line start to current pos (end-of-input)
        var footer = new byte[buffer.limit() - lineStart];
        buffer.get(lineStart, footer, 0, footer.length);
        return footer;
    }

    /**
     * Advances the buffer past the first newline and returns everything that was skipped.
     *
     * @param buffer buffer positioned at index 0
     * @return the bytes from index 0 up to the new position
     */
    private static byte[] readHeader(ByteBuffer buffer) { // read up to and including first newline (or end-of-input)
        while (buffer.hasRemaining() && buffer.get() != '\n')
            ;
        var header = new byte[buffer.position()];
        buffer.get(0, header, 0, header.length);
        return header;
    }

    /** Parse result of one byte range of the file. */
    private static class Partition {
        final byte[] header; // bytes up to and including the first newline; null for the first partition
        byte[] footer; // bytes of the unfinished last line, or null; replaced when a line is carried forward
        final Stats[] stats; // hash table with the statistics of all complete lines

        Partition(byte[] header, byte[] footer, Stats[] stats) {
            this.header = header;
            this.footer = footer;
            this.stats = stats;
        }
    }

    /** Hash table entry: an off-heap key together with its running statistics. */
    private static class Stats { // min, max, and sum values are modeled with integral types that represent tenths of a unit
        final long keyAddr; // address of 4-byte integer array
        final int keyLen; // number of 4-byte integers starting at address
        final int lastBytes; // number of bytes packed into last key int (1, 2, 3 or 4)
        final int hash;
        int min = Integer.MAX_VALUE;
        int max = Integer.MIN_VALUE;
        long sum;
        long count;

        Stats(long keyAddr, int keyLen, int lastBytes, int hash) {
            this.keyAddr = keyAddr;
            this.keyLen = keyLen;
            this.lastBytes = lastBytes;
            this.hash = hash;
        }
    }
}