armandino: second attempt (#445)
This commit is contained in:
parent
b1e6a120a4
commit
7bd2df7c59
@ -16,5 +16,5 @@
|
||||
#
|
||||
|
||||
|
||||
JAVA_OPTS=""
|
||||
JAVA_OPTS="--enable-preview -da -dsa -Xms128m -Xmx128m -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -XX:+AlwaysPreTouch"
|
||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_armandino
|
||||
|
@ -15,188 +15,143 @@
|
||||
*/
|
||||
package dev.morling.onebrc;
|
||||
|
||||
import sun.misc.Unsafe;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.PrintStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.lang.foreign.Arena;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.channels.FileChannel;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Objects;
|
||||
import java.util.TreeMap;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static java.nio.channels.FileChannel.MapMode.READ_ONLY;
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static java.util.stream.Collectors.toMap;
|
||||
|
||||
public class CalculateAverage_armandino {
|
||||
|
||||
private static final String FILE = "./measurements.txt";
|
||||
private static final Path FILE = Path.of("./measurements.txt");
|
||||
|
||||
private static final int MAX_KEY_LENGTH = 100;
|
||||
private static final int NUM_CHUNKS = Math.max(8, Runtime.getRuntime().availableProcessors());
|
||||
private static final int INITIAL_MAP_CAPACITY = 8192;
|
||||
private static final byte SEMICOLON = 59;
|
||||
private static final byte NL = 10;
|
||||
private static final byte DOT = 46;
|
||||
private static final byte MINUS = 45;
|
||||
private static final byte ZERO_DIGIT = 48;
|
||||
private static final Unsafe UNSAFE = getUnsafe();
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
Aggregator aggregator = new Aggregator();
|
||||
aggregator.process();
|
||||
aggregator.printStats();
|
||||
var channel = FileChannel.open(FILE, StandardOpenOption.READ);
|
||||
|
||||
var results = Arrays.stream(split(channel)).parallel()
|
||||
.map(chunk -> new ChunkProcessor().process(chunk.start, chunk.end))
|
||||
.flatMap(SimpleMap::stream)
|
||||
.collect(toMap(Stats::getKey, s -> s, CalculateAverage_armandino::mergeStats, TreeMap::new));
|
||||
|
||||
print(results.values());
|
||||
}
|
||||
|
||||
private static class Aggregator {
|
||||
private static Stats mergeStats(final Stats x, final Stats y) {
|
||||
x.min = Math.min(x.min, y.min);
|
||||
x.max = Math.max(x.max, y.max);
|
||||
x.count += y.count;
|
||||
x.sum += y.sum;
|
||||
return x;
|
||||
}
|
||||
|
||||
private final Map<Integer, Stats> map = new ConcurrentHashMap<>(2048);
|
||||
private static class ChunkProcessor {
|
||||
private final SimpleMap map = new SimpleMap(INITIAL_MAP_CAPACITY);
|
||||
|
||||
// A start/end pair delimiting one slice of the input; the split() methods in this
// listing are what construct these values.
private record Chunk(long start, long end) {
|
||||
}
|
||||
private SimpleMap process(final long chunkStart, final long chunkEnd) {
|
||||
long i = chunkStart;
|
||||
while (i < chunkEnd) {
|
||||
final long keyAddress = i;
|
||||
int keyHash = 0;
|
||||
int measurement = 0;
|
||||
byte b;
|
||||
|
||||
void process() throws Exception {
|
||||
var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
|
||||
final Chunk[] chunks = split(channel);
|
||||
final Thread[] threads = new Thread[chunks.length];
|
||||
while ((b = UNSAFE.getByte(i++)) != SEMICOLON) {
|
||||
keyHash = 31 * keyHash + b;
|
||||
}
|
||||
|
||||
for (int i = 0; i < chunks.length; i++) {
|
||||
final Chunk chunk = chunks[i];
|
||||
final int keyLength = (int) (i - keyAddress - 1);
|
||||
|
||||
threads[i] = Thread.ofVirtual().start(() -> {
|
||||
try {
|
||||
var bb = channel.map(READ_ONLY, chunk.start, chunk.end - chunk.start);
|
||||
process(bb);
|
||||
if ((b = UNSAFE.getByte(i++)) == MINUS) {
|
||||
while ((b = UNSAFE.getByte(i++)) != DOT) {
|
||||
measurement = measurement * 10 + b - ZERO_DIGIT;
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (Thread t : threads) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
|
||||
private static Chunk[] split(final FileChannel channel) throws IOException {
|
||||
final long fileSize = channel.size();
|
||||
if (fileSize < 10000) {
|
||||
return new Chunk[]{ new Chunk(0, fileSize) };
|
||||
}
|
||||
|
||||
final int numChunks = 8;
|
||||
final long chunkSize = fileSize / numChunks;
|
||||
final var chunks = new Chunk[numChunks];
|
||||
|
||||
for (int i = 0; i < numChunks; i++) {
|
||||
long start = 0;
|
||||
long end = chunkSize;
|
||||
|
||||
if (i > 0) {
|
||||
start = chunks[i - 1].end + 1;
|
||||
end = Math.min(start + chunkSize, fileSize);
|
||||
}
|
||||
|
||||
end = end == fileSize ? end : seekNextNewline(channel, end);
|
||||
chunks[i] = new Chunk(start, end);
|
||||
}
|
||||
return chunks;
|
||||
}
|
||||
|
||||
private static long seekNextNewline(final FileChannel channel, final long end) throws IOException {
|
||||
var bb = ByteBuffer.allocate(MAX_KEY_LENGTH);
|
||||
channel.position(end).read(bb);
|
||||
|
||||
for (int i = 0; i < bb.limit(); i++) {
|
||||
if (bb.get(i) == NL) {
|
||||
return end + i;
|
||||
}
|
||||
}
|
||||
|
||||
throw new IllegalStateException("Couldn't find next newline");
|
||||
}
|
||||
|
||||
private void process(final ByteBuffer bb) {
|
||||
final var sample = new Sample();
|
||||
var isKey = true;
|
||||
|
||||
for (long i = 0, sz = bb.limit(); i < sz; i++) {
|
||||
|
||||
final byte b = bb.get();
|
||||
|
||||
if (b == SEMICOLON) {
|
||||
isKey = false;
|
||||
}
|
||||
else if (b == NL) {
|
||||
isKey = true;
|
||||
addSample(sample);
|
||||
sample.reset();
|
||||
}
|
||||
else if (isKey) {
|
||||
sample.pushKey(b);
|
||||
}
|
||||
else if (b == DOT) {
|
||||
// skip
|
||||
}
|
||||
else if (b == MINUS) {
|
||||
sample.sign = -1;
|
||||
b = UNSAFE.getByte(i);
|
||||
measurement = measurement * 10 + b - ZERO_DIGIT;
|
||||
measurement = -measurement;
|
||||
i += 2;
|
||||
}
|
||||
else {
|
||||
sample.pushMeasurement(b);
|
||||
measurement = b - ZERO_DIGIT; // D1
|
||||
b = UNSAFE.getByte(i); // dot or D2
|
||||
|
||||
if (b == DOT) {
|
||||
measurement = measurement * 10 + UNSAFE.getByte(i + 1) - ZERO_DIGIT; // F
|
||||
i += 3;
|
||||
}
|
||||
else {
|
||||
measurement = measurement * 10 + b - ZERO_DIGIT; // D2
|
||||
measurement = measurement * 10 + UNSAFE.getByte(i + 2) - ZERO_DIGIT; // F
|
||||
i += 4; // skip NL
|
||||
}
|
||||
}
|
||||
|
||||
final Stats stats = map.putStats(keyHash, keyAddress, keyLength);
|
||||
stats.min = Math.min(stats.min, measurement);
|
||||
stats.max = Math.max(stats.max, measurement);
|
||||
stats.sum += measurement;
|
||||
stats.count++;
|
||||
}
|
||||
}
|
||||
|
||||
private void addSample(final Sample sample) {
|
||||
final Stats stats = map.computeIfAbsent(sample.keyHash,
|
||||
k -> new Stats(new String(sample.keyBytes, 0, sample.keyLength, UTF_8)));
|
||||
|
||||
final var val = sample.getMeasurement();
|
||||
|
||||
if (val < stats.min)
|
||||
stats.min = val;
|
||||
|
||||
if (val > stats.max)
|
||||
stats.max = val;
|
||||
|
||||
stats.sum += val;
|
||||
stats.count++;
|
||||
}
|
||||
|
||||
void printStats() {
|
||||
var sorted = new ArrayList<>(map.values());
|
||||
Collections.sort(sorted);
|
||||
|
||||
int size = sorted.size();
|
||||
|
||||
System.out.print('{');
|
||||
|
||||
for (Stats stats : sorted) {
|
||||
stats.print(System.out);
|
||||
if (--size > 0) {
|
||||
System.out.print(", ");
|
||||
}
|
||||
}
|
||||
System.out.println('}');
|
||||
return map;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Stats implements Comparable<Stats> {
|
||||
private final String city;
|
||||
private String key;
|
||||
private final byte[] keyBytes;
|
||||
private final int keyLength;
|
||||
private final int keyHash;
|
||||
private int min = Integer.MAX_VALUE;
|
||||
private int max = Integer.MIN_VALUE;
|
||||
private long sum;
|
||||
private int count;
|
||||
private long sum;
|
||||
|
||||
private Stats(String city) {
|
||||
this.city = city;
|
||||
private Stats(long keyAddress, int keyLength, int keyHash) {
|
||||
this.keyLength = keyLength;
|
||||
this.keyBytes = new byte[keyLength];
|
||||
this.keyHash = keyHash;
|
||||
|
||||
for (int i = 0; i < keyLength; i++) {
|
||||
keyBytes[i] = UNSAFE.getByte(keyAddress++);
|
||||
}
|
||||
}
|
||||
|
||||
String getKey() {
|
||||
if (key == null) {
|
||||
key = new String(keyBytes, 0, keyLength, UTF_8);
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(final Stats o) {
|
||||
return city.compareTo(o.city);
|
||||
return getKey().compareTo(o.getKey());
|
||||
}
|
||||
|
||||
void print(final PrintStream out) {
|
||||
out.print(city);
|
||||
out.print(key);
|
||||
out.print('=');
|
||||
out.print(round(min / 10f));
|
||||
out.print('/');
|
||||
@ -210,32 +165,148 @@ public class CalculateAverage_armandino {
|
||||
}
|
||||
}
|
||||
|
||||
private static class Sample {
|
||||
private final byte[] keyBytes = new byte[MAX_KEY_LENGTH];
|
||||
private int keyLength;
|
||||
private int keyHash;
|
||||
private int measurement;
|
||||
private int sign = 1;
|
||||
private static void print(final Collection<Stats> sorted) {
|
||||
int size = sorted.size();
|
||||
System.out.print('{');
|
||||
for (Stats stats : sorted) {
|
||||
stats.print(System.out);
|
||||
if (--size > 0) {
|
||||
System.out.print(", ");
|
||||
}
|
||||
}
|
||||
System.out.println('}');
|
||||
}
|
||||
|
||||
void pushKey(byte b) {
|
||||
keyBytes[keyLength++] = b;
|
||||
keyHash = 31 * keyHash + b;
|
||||
private static Chunk[] split(final FileChannel channel) throws IOException {
|
||||
final long fileSize = channel.size();
|
||||
long start = channel.map(READ_ONLY, 0, fileSize, Arena.global()).address();
|
||||
final long endAddress = start + fileSize;
|
||||
if (fileSize < 10000) {
|
||||
return new Chunk[]{ new Chunk(start, endAddress) };
|
||||
}
|
||||
|
||||
void pushMeasurement(byte b) {
|
||||
final int i = b - '0';
|
||||
measurement = measurement * 10 + i;
|
||||
final long chunkSize = fileSize / NUM_CHUNKS;
|
||||
final var chunks = new Chunk[NUM_CHUNKS];
|
||||
long end = start + chunkSize;
|
||||
|
||||
for (int i = 0; i < NUM_CHUNKS; i++) {
|
||||
if (i > 0) {
|
||||
start = chunks[i - 1].end;
|
||||
end = Math.min(start + chunkSize, endAddress);
|
||||
}
|
||||
if (end < endAddress) {
|
||||
while (UNSAFE.getByte(end) != NL) {
|
||||
end++;
|
||||
}
|
||||
end++;
|
||||
}
|
||||
chunks[i] = new Chunk(start, end);
|
||||
}
|
||||
return chunks;
|
||||
}
|
||||
|
||||
// Boundaries of one slice of the input.
// NOTE(review): the split() immediately above builds these from addresses inside the
// memory-mapped file, with end placed one past a newline; another split() in this listing
// uses file offsets instead — confirm which convention applies to this declaration.
private record Chunk(long start, long end) {
|
||||
}
|
||||
|
||||
private static Unsafe getUnsafe() {
|
||||
try {
|
||||
Field unsafe = Unsafe.class.getDeclaredField("theUnsafe");
|
||||
unsafe.setAccessible(true);
|
||||
return (Unsafe) unsafe.get(null);
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static class SimpleMap {
|
||||
private Stats[] table;
|
||||
|
||||
/**
 * Creates an empty table with {@code initialCapacity} slots.
 * <p>
 * NOTE(review): slot indices are computed elsewhere in this class as
 * {@code (table.length - 1) & keyHash}, which presumably expects the capacity to be a
 * power of two; nothing here verifies that — confirm callers only pass such values.
 */
SimpleMap(int initialCapacity) {
|
||||
table = new Stats[initialCapacity];
|
||||
}
|
||||
|
||||
int getMeasurement() {
|
||||
return sign * measurement;
|
||||
Stream<Stats> stream() {
|
||||
return Arrays.stream(table).filter(Objects::nonNull);
|
||||
}
|
||||
|
||||
void reset() {
|
||||
keyHash = 0;
|
||||
keyLength = 0;
|
||||
measurement = 0;
|
||||
sign = 1;
|
||||
private void resize() {
|
||||
var copy = new SimpleMap(table.length * 2);
|
||||
for (Stats s : table) {
|
||||
if (s != null) {
|
||||
final int pos = (copy.table.length - 1) & s.keyHash;
|
||||
int i = pos;
|
||||
|
||||
if (copy.table[i] == null) {
|
||||
copy.table[i] = s;
|
||||
continue;
|
||||
}
|
||||
|
||||
while (i < copy.table.length && copy.table[i] != null) {
|
||||
i++;
|
||||
}
|
||||
if (i == copy.table.length) {
|
||||
i = pos;
|
||||
while (i >= 0 && copy.table[i] != null) {
|
||||
i--;
|
||||
}
|
||||
}
|
||||
if (i < 0) {
|
||||
// shouldn't happen because put() is called after increasing size
|
||||
throw new IllegalStateException("table is full");
|
||||
}
|
||||
copy.table[i] = s;
|
||||
}
|
||||
}
|
||||
table = copy.table;
|
||||
}
|
||||
|
||||
Stats putStats(final int keyHash, final long keyAddress, final int keyLength) {
|
||||
final int pos = (table.length - 1) & keyHash;
|
||||
|
||||
Stats stats = table[pos];
|
||||
if (stats == null)
|
||||
return createAt(table, keyAddress, keyLength, keyHash, pos);
|
||||
if (stats.keyHash == keyHash && keysEqual(stats, keyAddress, keyLength))
|
||||
return stats;
|
||||
|
||||
int i = pos;
|
||||
while (++i < table.length) {
|
||||
stats = table[i];
|
||||
if (stats == null)
|
||||
return createAt(table, keyAddress, keyLength, keyHash, i);
|
||||
if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength))
|
||||
return stats;
|
||||
}
|
||||
|
||||
i = pos;
|
||||
while (i-- > 0) {
|
||||
stats = table[i];
|
||||
if (stats == null)
|
||||
return createAt(table, keyAddress, keyLength, keyHash, i);
|
||||
if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength))
|
||||
return stats;
|
||||
}
|
||||
resize();
|
||||
return putStats(keyHash, keyAddress, keyLength);
|
||||
}
|
||||
|
||||
private boolean keysEqual(Stats stats, long keyAddress, final int keyLength) {
|
||||
if (stats.keyLength != keyLength) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < keyLength; i++) {
|
||||
if (stats.keyBytes[i] != UNSAFE.getByte(keyAddress++)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private static Stats createAt(Stats[] table, long keyAddress, int keyLength, int key, int i) {
|
||||
Stats stats = new Stats(keyAddress, keyLength, key);
|
||||
table[i] = stats;
|
||||
return stats;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user