1brc/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java

313 lines
10 KiB
Java
Raw Normal View History

2024-01-04 08:26:17 +01:00
/*
* Copyright 2023 The original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.morling.onebrc;
2024-01-16 22:04:37 +01:00
import sun.misc.Unsafe;
2024-01-04 08:26:17 +01:00
import java.io.IOException;
import java.io.PrintStream;
2024-01-16 22:04:37 +01:00
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
2024-01-04 08:26:17 +01:00
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
2024-01-16 22:04:37 +01:00
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import java.util.TreeMap;
import java.util.stream.Stream;
2024-01-04 08:26:17 +01:00
import static java.nio.channels.FileChannel.MapMode.READ_ONLY;
import static java.nio.charset.StandardCharsets.UTF_8;
2024-01-16 22:04:37 +01:00
import static java.util.stream.Collectors.toMap;
2024-01-04 08:26:17 +01:00
public class CalculateAverage_armandino {
2024-01-16 22:04:37 +01:00
private static final Path FILE = Path.of("./measurements.txt");
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
private static final int NUM_CHUNKS = Math.max(8, Runtime.getRuntime().availableProcessors());
private static final int INITIAL_MAP_CAPACITY = 8192;
2024-01-04 08:26:17 +01:00
private static final byte SEMICOLON = 59;
private static final byte NL = 10;
private static final byte DOT = 46;
private static final byte MINUS = 45;
2024-01-16 22:04:37 +01:00
private static final byte ZERO_DIGIT = 48;
private static final Unsafe UNSAFE = getUnsafe();
2024-01-04 08:26:17 +01:00
public static void main(String[] args) throws Exception {
2024-01-16 22:04:37 +01:00
var channel = FileChannel.open(FILE, StandardOpenOption.READ);
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
var results = Arrays.stream(split(channel)).parallel()
.map(chunk -> new ChunkProcessor().process(chunk.start, chunk.end))
.flatMap(SimpleMap::stream)
.collect(toMap(Stats::getKey, s -> s, CalculateAverage_armandino::mergeStats, TreeMap::new));
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
print(results.values());
}
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
private static Stats mergeStats(final Stats x, final Stats y) {
x.min = Math.min(x.min, y.min);
x.max = Math.max(x.max, y.max);
x.count += y.count;
x.sum += y.sum;
return x;
}
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
private static class ChunkProcessor {
private final SimpleMap map = new SimpleMap(INITIAL_MAP_CAPACITY);
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
private SimpleMap process(final long chunkStart, final long chunkEnd) {
long i = chunkStart;
while (i < chunkEnd) {
final long keyAddress = i;
int keyHash = 0;
int measurement = 0;
byte b;
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
while ((b = UNSAFE.getByte(i++)) != SEMICOLON) {
keyHash = 31 * keyHash + b;
2024-01-04 08:26:17 +01:00
}
2024-01-16 22:04:37 +01:00
final int keyLength = (int) (i - keyAddress - 1);
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
if ((b = UNSAFE.getByte(i++)) == MINUS) {
while ((b = UNSAFE.getByte(i++)) != DOT) {
measurement = measurement * 10 + b - ZERO_DIGIT;
}
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
b = UNSAFE.getByte(i);
measurement = measurement * 10 + b - ZERO_DIGIT;
measurement = -measurement;
i += 2;
2024-01-04 08:26:17 +01:00
}
else {
2024-01-16 22:04:37 +01:00
measurement = b - ZERO_DIGIT; // D1
b = UNSAFE.getByte(i); // dot or D2
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
if (b == DOT) {
measurement = measurement * 10 + UNSAFE.getByte(i + 1) - ZERO_DIGIT; // F
i += 3;
}
else {
measurement = measurement * 10 + b - ZERO_DIGIT; // D2
measurement = measurement * 10 + UNSAFE.getByte(i + 2) - ZERO_DIGIT; // F
i += 4; // skip NL
}
2024-01-04 08:26:17 +01:00
}
2024-01-16 22:04:37 +01:00
final Stats stats = map.putStats(keyHash, keyAddress, keyLength);
stats.min = Math.min(stats.min, measurement);
stats.max = Math.max(stats.max, measurement);
stats.sum += measurement;
stats.count++;
2024-01-04 08:26:17 +01:00
}
2024-01-16 22:04:37 +01:00
return map;
2024-01-04 08:26:17 +01:00
}
}
private static class Stats implements Comparable<Stats> {
2024-01-16 22:04:37 +01:00
private String key;
private final byte[] keyBytes;
private final int keyLength;
private final int keyHash;
2024-01-04 08:26:17 +01:00
private int min = Integer.MAX_VALUE;
private int max = Integer.MIN_VALUE;
private int count;
2024-01-16 22:04:37 +01:00
private long sum;
private Stats(long keyAddress, int keyLength, int keyHash) {
this.keyLength = keyLength;
this.keyBytes = new byte[keyLength];
this.keyHash = keyHash;
for (int i = 0; i < keyLength; i++) {
keyBytes[i] = UNSAFE.getByte(keyAddress++);
}
}
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
String getKey() {
if (key == null) {
key = new String(keyBytes, 0, keyLength, UTF_8);
}
return key;
2024-01-04 08:26:17 +01:00
}
@Override
public int compareTo(final Stats o) {
2024-01-16 22:04:37 +01:00
return getKey().compareTo(o.getKey());
2024-01-04 08:26:17 +01:00
}
void print(final PrintStream out) {
2024-01-16 22:04:37 +01:00
out.print(key);
2024-01-04 08:26:17 +01:00
out.print('=');
out.print(round(min / 10f));
out.print('/');
out.print(round((sum / 10f) / count));
out.print('/');
out.print(round(max) / 10f);
}
private static double round(double value) {
return Math.round(value * 10.0) / 10.0;
}
}
2024-01-16 22:04:37 +01:00
private static void print(final Collection<Stats> sorted) {
int size = sorted.size();
System.out.print('{');
for (Stats stats : sorted) {
stats.print(System.out);
if (--size > 0) {
System.out.print(", ");
}
}
System.out.println('}');
}
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
private static Chunk[] split(final FileChannel channel) throws IOException {
final long fileSize = channel.size();
long start = channel.map(READ_ONLY, 0, fileSize, Arena.global()).address();
final long endAddress = start + fileSize;
if (fileSize < 10000) {
return new Chunk[]{ new Chunk(start, endAddress) };
2024-01-04 08:26:17 +01:00
}
2024-01-16 22:04:37 +01:00
final long chunkSize = fileSize / NUM_CHUNKS;
final var chunks = new Chunk[NUM_CHUNKS];
long end = start + chunkSize;
for (int i = 0; i < NUM_CHUNKS; i++) {
if (i > 0) {
start = chunks[i - 1].end;
end = Math.min(start + chunkSize, endAddress);
}
if (end < endAddress) {
while (UNSAFE.getByte(end) != NL) {
end++;
}
end++;
}
chunks[i] = new Chunk(start, end);
2024-01-04 08:26:17 +01:00
}
2024-01-16 22:04:37 +01:00
return chunks;
}
private record Chunk(long start, long end) {
}
private static Unsafe getUnsafe() {
try {
Field unsafe = Unsafe.class.getDeclaredField("theUnsafe");
unsafe.setAccessible(true);
return (Unsafe) unsafe.get(null);
}
catch (Exception e) {
throw new RuntimeException(e);
}
}
private static class SimpleMap {
private Stats[] table;
2024-01-04 08:26:17 +01:00
2024-01-16 22:04:37 +01:00
SimpleMap(int initialCapacity) {
table = new Stats[initialCapacity];
}
Stream<Stats> stream() {
return Arrays.stream(table).filter(Objects::nonNull);
}
private void resize() {
var copy = new SimpleMap(table.length * 2);
for (Stats s : table) {
if (s != null) {
final int pos = (copy.table.length - 1) & s.keyHash;
int i = pos;
if (copy.table[i] == null) {
copy.table[i] = s;
continue;
}
while (i < copy.table.length && copy.table[i] != null) {
i++;
}
if (i == copy.table.length) {
i = pos;
while (i >= 0 && copy.table[i] != null) {
i--;
}
}
if (i < 0) {
// shouldn't happen because put() is called after increasing size
throw new IllegalStateException("table is full");
}
copy.table[i] = s;
}
}
table = copy.table;
}
Stats putStats(final int keyHash, final long keyAddress, final int keyLength) {
final int pos = (table.length - 1) & keyHash;
Stats stats = table[pos];
if (stats == null)
return createAt(table, keyAddress, keyLength, keyHash, pos);
if (stats.keyHash == keyHash && keysEqual(stats, keyAddress, keyLength))
return stats;
int i = pos;
while (++i < table.length) {
stats = table[i];
if (stats == null)
return createAt(table, keyAddress, keyLength, keyHash, i);
if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength))
return stats;
}
i = pos;
while (i-- > 0) {
stats = table[i];
if (stats == null)
return createAt(table, keyAddress, keyLength, keyHash, i);
if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength))
return stats;
}
resize();
return putStats(keyHash, keyAddress, keyLength);
}
private boolean keysEqual(Stats stats, long keyAddress, final int keyLength) {
if (stats.keyLength != keyLength) {
return false;
}
for (int i = 0; i < keyLength; i++) {
if (stats.keyBytes[i] != UNSAFE.getByte(keyAddress++)) {
return false;
}
}
return true;
2024-01-04 08:26:17 +01:00
}
2024-01-16 22:04:37 +01:00
private static Stats createAt(Stats[] table, long keyAddress, int keyLength, int key, int i) {
Stats stats = new Stats(keyAddress, keyLength, key);
table[i] = stats;
return stats;
2024-01-04 08:26:17 +01:00
}
}
}