From 7bd2df7c590773a497e1c67f86b8f7c91173e657 Mon Sep 17 00:00:00 2001 From: Arman Sharif Date: Tue, 16 Jan 2024 13:04:37 -0800 Subject: [PATCH] armandino: second attempt (#445) --- calculate_average_armandino.sh | 2 +- .../onebrc/CalculateAverage_armandino.java | 381 +++++++++++------- 2 files changed, 227 insertions(+), 156 deletions(-) diff --git a/calculate_average_armandino.sh b/calculate_average_armandino.sh index 719953d..6ac5c16 100755 --- a/calculate_average_armandino.sh +++ b/calculate_average_armandino.sh @@ -16,5 +16,5 @@ # -JAVA_OPTS="" +JAVA_OPTS="--enable-preview -da -dsa -Xms128m -Xmx128m -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -XX:+AlwaysPreTouch" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_armandino diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java b/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java index 21abbb1..dce3a33 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java @@ -15,188 +15,143 @@ */ package dev.morling.onebrc; +import sun.misc.Unsafe; + import java.io.IOException; import java.io.PrintStream; -import java.nio.ByteBuffer; +import java.lang.foreign.Arena; +import java.lang.reflect.Field; import java.nio.channels.FileChannel; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Arrays; +import java.util.Collection; +import java.util.Objects; +import java.util.TreeMap; +import java.util.stream.Stream; import static java.nio.channels.FileChannel.MapMode.READ_ONLY; import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.stream.Collectors.toMap; public class CalculateAverage_armandino { - private static final String FILE = "./measurements.txt"; + private static final Path FILE = Path.of("./measurements.txt"); - private static final int MAX_KEY_LENGTH = 100; + private static final int NUM_CHUNKS = Math.max(8, Runtime.getRuntime().availableProcessors()); + private static final int INITIAL_MAP_CAPACITY = 8192; private static final byte SEMICOLON = 59; private static final byte NL = 10; private static final byte DOT = 46; private static final byte MINUS = 45; + private static final byte ZERO_DIGIT = 48; + private static final Unsafe UNSAFE = getUnsafe(); public static void main(String[] args) throws Exception { - Aggregator aggregator = new Aggregator(); - aggregator.process(); - aggregator.printStats(); + var channel = FileChannel.open(FILE, StandardOpenOption.READ); + + var results = Arrays.stream(split(channel)).parallel() + .map(chunk -> new ChunkProcessor().process(chunk.start, chunk.end)) + .flatMap(SimpleMap::stream) + .collect(toMap(Stats::getKey, s -> s, CalculateAverage_armandino::mergeStats, TreeMap::new)); + + print(results.values()); } - private static class Aggregator { + private static Stats mergeStats(final Stats x, final Stats y) { + x.min = Math.min(x.min, y.min); + x.max = Math.max(x.max, y.max); + x.count += y.count; + x.sum += y.sum; + return x; + } - private final Map map = new ConcurrentHashMap<>(2048); + private static class ChunkProcessor { + private final SimpleMap map = new SimpleMap(INITIAL_MAP_CAPACITY); - private record Chunk(long start, long end) { - } + private SimpleMap process(final long chunkStart, final long chunkEnd) { + long i = chunkStart; + while (i < chunkEnd) { + final long keyAddress = i; + int keyHash = 0; + int measurement = 0; + byte b; - void process() throws Exception { - var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ); - final Chunk[] chunks = split(channel); - final Thread[] threads = new Thread[chunks.length]; + while ((b = UNSAFE.getByte(i++)) != SEMICOLON) { + keyHash = 31 * keyHash + b; + } - for (int i = 0; i < chunks.length; i++) { - final Chunk chunk = chunks[i]; + final int keyLength = (int) (i - keyAddress - 1); - threads[i] = Thread.ofVirtual().start(() -> { - try { - var bb = channel.map(READ_ONLY, chunk.start, chunk.end - chunk.start); - process(bb); + if ((b = UNSAFE.getByte(i++)) == MINUS) { + while ((b = UNSAFE.getByte(i++)) != DOT) { + measurement = measurement * 10 + b - ZERO_DIGIT; } - catch (IOException e) { - throw new RuntimeException(e); - } - }); - } - for (Thread t : threads) { - t.join(); - } - } - - private static Chunk[] split(final FileChannel channel) throws IOException { - final long fileSize = channel.size(); - if (fileSize < 10000) { - return new Chunk[]{ new Chunk(0, fileSize) }; - } - - final int numChunks = 8; - final long chunkSize = fileSize / numChunks; - final var chunks = new Chunk[numChunks]; - - for (int i = 0; i < numChunks; i++) { - long start = 0; - long end = chunkSize; - - if (i > 0) { - start = chunks[i - 1].end + 1; - end = Math.min(start + chunkSize, fileSize); - } - - end = end == fileSize ? end : seekNextNewline(channel, end); - chunks[i] = new Chunk(start, end); - } - return chunks; - } - - private static long seekNextNewline(final FileChannel channel, final long end) throws IOException { - var bb = ByteBuffer.allocate(MAX_KEY_LENGTH); - channel.position(end).read(bb); - - for (int i = 0; i < bb.limit(); i++) { - if (bb.get(i) == NL) { - return end + i; - } - } - - throw new IllegalStateException("Couldn't find next newline"); - } - - private void process(final ByteBuffer bb) { - final var sample = new Sample(); - var isKey = true; - - for (long i = 0, sz = bb.limit(); i < sz; i++) { - - final byte b = bb.get(); - - if (b == SEMICOLON) { - isKey = false; - } - else if (b == NL) { - isKey = true; - addSample(sample); - sample.reset(); - } - else if (isKey) { - sample.pushKey(b); - } - else if (b == DOT) { - // skip - } - else if (b == MINUS) { - sample.sign = -1; + b = UNSAFE.getByte(i); + measurement = measurement * 10 + b - ZERO_DIGIT; + measurement = -measurement; + i += 2; } else { - sample.pushMeasurement(b); + measurement = b - ZERO_DIGIT; // D1 + b = UNSAFE.getByte(i); // dot or D2 + + if (b == DOT) { + measurement = measurement * 10 + UNSAFE.getByte(i + 1) - ZERO_DIGIT; // F + i += 3; + } + else { + measurement = measurement * 10 + b - ZERO_DIGIT; // D2 + measurement = measurement * 10 + UNSAFE.getByte(i + 2) - ZERO_DIGIT; // F + i += 4; // skip NL + } } + + final Stats stats = map.putStats(keyHash, keyAddress, keyLength); + stats.min = Math.min(stats.min, measurement); + stats.max = Math.max(stats.max, measurement); + stats.sum += measurement; + stats.count++; } - } - - private void addSample(final Sample sample) { - final Stats stats = map.computeIfAbsent(sample.keyHash, - k -> new Stats(new String(sample.keyBytes, 0, sample.keyLength, UTF_8))); - - final var val = sample.getMeasurement(); - - if (val < stats.min) - stats.min = val; - - if (val > stats.max) - stats.max = val; - - stats.sum += val; - stats.count++; - } - - void printStats() { - var sorted = new ArrayList<>(map.values()); - Collections.sort(sorted); - - int size = sorted.size(); - - System.out.print('{'); - - for (Stats stats : sorted) { - stats.print(System.out); - if (--size > 0) { - System.out.print(", "); - } - } - System.out.println('}'); + return map; } } private static class Stats implements Comparable { - private final String city; + private String key; + private final byte[] keyBytes; + private final int keyLength; + private final int keyHash; private int min = Integer.MAX_VALUE; private int max = Integer.MIN_VALUE; - private long sum; private int count; + private long sum; - private Stats(String city) { - this.city = city; + private Stats(long keyAddress, int keyLength, int keyHash) { + this.keyLength = keyLength; + this.keyBytes = new byte[keyLength]; + this.keyHash = keyHash; + + for (int i = 0; i < keyLength; i++) { + keyBytes[i] = UNSAFE.getByte(keyAddress++); + } + } + + String getKey() { + if (key == null) { + key = new String(keyBytes, 0, keyLength, UTF_8); + } + return key; } @Override public int compareTo(final Stats o) { - return city.compareTo(o.city); + return getKey().compareTo(o.getKey()); } void print(final PrintStream out) { - out.print(city); + out.print(key); out.print('='); out.print(round(min / 10f)); out.print('/'); @@ -210,32 +165,148 @@ public class CalculateAverage_armandino { } } - private static class Sample { - private final byte[] keyBytes = new byte[MAX_KEY_LENGTH]; - private int keyLength; - private int keyHash; - private int measurement; - private int sign = 1; + private static void print(final Collection sorted) { + int size = sorted.size(); + System.out.print('{'); + for (Stats stats : sorted) { + stats.print(System.out); + if (--size > 0) { + System.out.print(", "); + } + } + System.out.println('}'); + } - void pushKey(byte b) { - keyBytes[keyLength++] = b; - keyHash = 31 * keyHash + b; + private static Chunk[] split(final FileChannel channel) throws IOException { + final long fileSize = channel.size(); + long start = channel.map(READ_ONLY, 0, fileSize, Arena.global()).address(); + final long endAddress = start + fileSize; + if (fileSize < 10000) { + return new Chunk[]{ new Chunk(start, endAddress) }; } - void pushMeasurement(byte b) { - final int i = b - '0'; - measurement = measurement * 10 + i; + final long chunkSize = fileSize / NUM_CHUNKS; + final var chunks = new Chunk[NUM_CHUNKS]; + long end = start + chunkSize; + + for (int i = 0; i < NUM_CHUNKS; i++) { + if (i > 0) { + start = chunks[i - 1].end; + end = Math.min(start + chunkSize, endAddress); + } + if (end < endAddress) { + while (UNSAFE.getByte(end) != NL) { + end++; + } + end++; + } + chunks[i] = new Chunk(start, end); + } + return chunks; + } + + private record Chunk(long start, long end) { + } + + private static Unsafe getUnsafe() { + try { + Field unsafe = Unsafe.class.getDeclaredField("theUnsafe"); + unsafe.setAccessible(true); + return (Unsafe) unsafe.get(null); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static class SimpleMap { + private Stats[] table; + + SimpleMap(int initialCapacity) { + table = new Stats[initialCapacity]; } - int getMeasurement() { - return sign * measurement; + Stream stream() { + return Arrays.stream(table).filter(Objects::nonNull); } - void reset() { - keyHash = 0; - keyLength = 0; - measurement = 0; - sign = 1; + private void resize() { + var copy = new SimpleMap(table.length * 2); + for (Stats s : table) { + if (s != null) { + final int pos = (copy.table.length - 1) & s.keyHash; + int i = pos; + + if (copy.table[i] == null) { + copy.table[i] = s; + continue; + } + + while (i < copy.table.length && copy.table[i] != null) { + i++; + } + if (i == copy.table.length) { + i = pos; + while (i >= 0 && copy.table[i] != null) { + i--; + } + } + if (i < 0) { + // shouldn't happen because put() is called after increasing size + throw new IllegalStateException("table is full"); + } + copy.table[i] = s; + } + } + table = copy.table; + } + + Stats putStats(final int keyHash, final long keyAddress, final int keyLength) { + final int pos = (table.length - 1) & keyHash; + + Stats stats = table[pos]; + if (stats == null) + return createAt(table, keyAddress, keyLength, keyHash, pos); + if (stats.keyHash == keyHash && keysEqual(stats, keyAddress, keyLength)) + return stats; + + int i = pos; + while (++i < table.length) { + stats = table[i]; + if (stats == null) + return createAt(table, keyAddress, keyLength, keyHash, i); + if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength)) + return stats; + } + + i = pos; + while (i-- > 0) { + stats = table[i]; + if (stats == null) + return createAt(table, keyAddress, keyLength, keyHash, i); + if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength)) + return stats; + } + resize(); + return putStats(keyHash, keyAddress, keyLength); + } + + private boolean keysEqual(Stats stats, long keyAddress, final int keyLength) { + if (stats.keyLength != keyLength) { + return false; + } + for (int i = 0; i < keyLength; i++) { + if (stats.keyBytes[i] != UNSAFE.getByte(keyAddress++)) { + return false; + } + } + return true; + } + + private static Stats createAt(Stats[] table, long keyAddress, int keyLength, int key, int i) { + Stats stats = new Stats(keyAddress, keyLength, key); + table[i] = stats; + return stats; } } }