From 37a50cb2afd7d9fe6032988fdce692b25ac79b54 Mon Sep 17 00:00:00 2001 From: Richard Startin Date: Fri, 12 Jan 2024 08:06:09 +0000 Subject: [PATCH] update richardstartin submission (#325) --- .../CalculateAverage_richardstartin.java | 313 +++++++----------- 1 file changed, 128 insertions(+), 185 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_richardstartin.java b/src/main/java/dev/morling/onebrc/CalculateAverage_richardstartin.java index b619c25..2a445bf 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_richardstartin.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_richardstartin.java @@ -26,11 +26,10 @@ import java.util.Arrays; import java.util.List; import java.util.TreeMap; import java.util.concurrent.ForkJoinPool; -import java.util.concurrent.RecursiveTask; +import java.util.concurrent.RecursiveAction; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.IntSupplier; +import java.util.function.Consumer; public class CalculateAverage_richardstartin { @@ -46,58 +45,64 @@ public class CalculateAverage_richardstartin { } } - static String bufferToString(ByteBuffer slice) { - byte[] bytes = new byte[slice.limit()]; - slice.get(0, bytes); - return new String(bytes, StandardCharsets.UTF_8); - } - - static double parseTemperature(ByteBuffer slice) { - // credit: adapted from spullara's submission - int value = 0; - int negative = 1; - int i = 0; - while (i != slice.limit()) { - byte b = slice.get(i++); - switch (b) { - case '-': - negative = -1; - case '.': - break; - default: - value = 10 * value + (b - '0'); + record Slot(byte[] key, int[] aggregates) { + private static final int WIDTH = 8; + private static int[] newAggregates(int stripes) { + var aggregates = new int[stripes * WIDTH]; + for (int i = 0; i < aggregates.length; i += WIDTH) { + aggregates[i] = Integer.MAX_VALUE; + aggregates[i + 1] = Integer.MIN_VALUE; } + return aggregates; + } + Slot(byte[] key, int stripes) { + this(key, newAggregates(stripes)); } - value *= negative; - return value / 10.0; - } - @FunctionalInterface - interface IndexedStringConsumer { - void accept(String value, int index); + void update(int stripe, int value) { + int i = stripe * WIDTH; + aggregates[i] = Math.min(value, aggregates[i]); + aggregates[i + 1] = Math.max(value, aggregates[i + 1]); + aggregates[i + 2] += value; + aggregates[i + 3]++; + } + + public ResultRow toResultRow() { + int min = Integer.MAX_VALUE; + int max = Integer.MIN_VALUE; + int sum = 0; + int count = 0; + for (int i = 0; i < aggregates.length; i += WIDTH) { + min = Math.min(min, aggregates[i]); + max = Math.max(max, aggregates[i + 1]); + sum += aggregates[i + 2]; + count += aggregates[i + 3]; + } + return new ResultRow(min * 0.1, 0.1 * sum / count, max * 0.1); + } + + public String toKey() { + return new String(key, StandardCharsets.UTF_8); + } } /** Maps text to an integer encoding. Adapted from async-profiler. */ public static class Dictionary { - private static final int ROW_BITS = 12; + private static final int ROW_BITS = 11; private static final int ROWS = (1 << ROW_BITS); - private static final int CELLS = 3; - private static final int TABLE_CAPACITY = (ROWS * CELLS); + private static final int TABLE_CAPACITY = ROWS; - private final Table table = new Table(nextBaseIndex()); + private final Table table = new Table(this, nextBaseIndex()); private static final AtomicIntegerFieldUpdater BASE_INDEX_UPDATER = AtomicIntegerFieldUpdater.newUpdater(Dictionary.class, "baseIndex"); volatile int baseIndex; - private void forEach(Table table, IndexedStringConsumer consumer) { - for (int i = 0; i < ROWS; i++) { - Row row = table.rows[i]; - for (int j = 0; j < CELLS; j++) { - var slice = row.keys.get(j); - if (slice != null) { - consumer.accept(bufferToString(slice), table.index(i, j)); - } + private void forEach(Table table, Consumer consumer) { + for (var row : table.rows) { + var slot = row.slot; + if (slot != null) { + consumer.accept(slot); } if (row.next != null) { forEach(row.next, consumer); @@ -105,33 +110,32 @@ public class CalculateAverage_richardstartin { } } - public void forEach(IndexedStringConsumer consumer) { + public void forEach(Consumer consumer) { forEach(this.table, consumer); } - public int encode(int hash, ByteBuffer slice) { + public Slot lookup(int hash, byte[] key, int length, int stripes) { Table table = this.table; while (true) { int rowIndex = Math.abs(hash) % ROWS; Row row = table.rows[rowIndex]; - for (int c = 0; c < CELLS; c++) { - ByteBuffer storedKey = row.keys.get(c); - if (storedKey == null) { - if (row.keys.compareAndSet(c, null, slice)) { - return table.index(rowIndex, c); - } - else { - storedKey = row.keys.get(c); - if (slice.equals(storedKey)) { - return table.index(rowIndex, c); - } - } + var storedSlot = row.slot; + if (storedSlot == null) { + Slot slot = new Slot(Arrays.copyOf(key, length), stripes); + if (row.compareAndSet(null, slot)) { + return slot; } - else if (slice.equals(storedKey)) { - return table.index(rowIndex, c); + else { + storedSlot = row.slot; + if (Arrays.equals(key, 0, length, storedSlot.key, 0, storedSlot.key.length)) { + return storedSlot; + } } } - table = row.getOrCreateNextTable(this::nextBaseIndex); + else if (Arrays.equals(key, 0, length, storedSlot.key, 0, storedSlot.key.length)) { + return storedSlot; + } + table = row.getOrCreateNextTable(); hash = Integer.rotateRight(hash, ROW_BITS); } } @@ -143,13 +147,19 @@ public class CalculateAverage_richardstartin { private static final class Row { private static final AtomicReferenceFieldUpdater NEXT_TABLE_UPDATER = AtomicReferenceFieldUpdater.newUpdater(Row.class, Table.class, "next"); - private final AtomicReferenceArray keys = new AtomicReferenceArray<>(CELLS); + private static final AtomicReferenceFieldUpdater SLOT_UPDATER = AtomicReferenceFieldUpdater.newUpdater(Row.class, Slot.class, "slot"); + private volatile Slot slot = null; + private final Dictionary dictionary; volatile Table next; - public Table getOrCreateNextTable(IntSupplier baseIndexSupplier) { + private Row(Dictionary dictionary) { + this.dictionary = dictionary; + } + + public Table getOrCreateNextTable() { Table next = this.next; if (next == null) { - Table newTable = new Table(baseIndexSupplier.getAsInt()); + Table newTable = new Table(dictionary, dictionary.nextBaseIndex()); if (NEXT_TABLE_UPDATER.compareAndSet(this, null, newTable)) { next = newTable; } @@ -159,6 +169,10 @@ public class CalculateAverage_richardstartin { } return next; } + + public boolean compareAndSet(Slot expected, Slot newSlot) { + return SLOT_UPDATER.compareAndSet(this, expected, newSlot); + } } private static final class Table { @@ -166,14 +180,10 @@ public class CalculateAverage_richardstartin { final Row[] rows; final int baseIndex; - private Table(int baseIndex) { + private Table(Dictionary dictionary, int baseIndex) { this.baseIndex = baseIndex; this.rows = new Row[ROWS]; - Arrays.setAll(rows, i -> new Row()); - } - - int index(int row, int col) { - return baseIndex + (col << ROW_BITS) + row; + Arrays.setAll(rows, i -> new Row(dictionary)); } } } @@ -182,23 +192,11 @@ public class CalculateAverage_richardstartin { return 0x101010101010101L * repeat; } - private static long compilePattern(char delimiter) { - return compilePattern(delimiter & 0xFFL); - } - private static long compilePattern(byte delimiter) { return compilePattern(delimiter & 0xFFL); } - private static final long NEW_LINE = compilePattern((byte) '\n'); - private static final long DELIMITER = compilePattern(';'); - - private static int firstInstance(long word, long pattern) { - long input = word ^ pattern; - long tmp = (input & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL; - tmp = ~(tmp | input | 0x7F7F7F7F7F7F7F7FL); - return Long.numberOfTrailingZeros(tmp) >>> 3; - } + private static final long DELIMITER = compilePattern((byte) ';'); private static int findLastNewLine(ByteBuffer buffer) { return findLastNewLine(buffer, buffer.limit() - 1); @@ -213,16 +211,19 @@ public class CalculateAverage_richardstartin { return 0; } - private static int findIndexOf(ByteBuffer buffer, int offset, long pattern) { + private static int findIndexOf(ByteBuffer buffer, int limit, int offset, long pattern) { int i = offset; - for (; i + Long.BYTES < buffer.limit(); i += Long.BYTES) { - int index = firstInstance(buffer.getLong(i), pattern); - if (index != Long.BYTES) { - return i + index; + for (; i < limit - Long.BYTES + 1; i += Long.BYTES) { + long word = buffer.getLong(i); + long input = word ^ pattern; + long tmp = (input & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL; + tmp |= input | 0x7F7F7F7F7F7F7F7FL; + if (tmp != -1L) { + return i + (Long.numberOfTrailingZeros(~tmp) >>> 3); } } byte b = (byte) (pattern & 0xFF); - for (; i < buffer.limit(); i++) { + for (; i < limit; i++) { if (buffer.get(i) == b) { return i; } @@ -230,56 +231,15 @@ public class CalculateAverage_richardstartin { return buffer.limit(); } - static class Page { - - static final int PAGE_SIZE = 1024; - static final int PAGE_SHIFT = Integer.numberOfTrailingZeros(PAGE_SIZE); - static final int PAGE_MASK = PAGE_SIZE - 1; - private static final double[] PAGE_PROTOTYPE = new double[PAGE_SIZE * 4]; - - static { - for (int i = 0; i < PAGE_SIZE * 4; i += 4) { - PAGE_PROTOTYPE[i + 1] = Double.POSITIVE_INFINITY; - PAGE_PROTOTYPE[i + 2] = Double.NEGATIVE_INFINITY; - } - } - - private static double[] newPage() { - return Arrays.copyOf(PAGE_PROTOTYPE, PAGE_PROTOTYPE.length); - } - - static void update(List pages, int position, double value) { - // find the page - int pageIndex = position >>> PAGE_SHIFT; - if (pageIndex >= pages.size()) { - for (int i = pages.size(); i <= pageIndex; i++) { - pages.add(null); - } - } - double[] page = pages.get(pageIndex); - if (page == null) { - page = newPage(); - pages.set(pageIndex, page); - } - - // update local aggregates - page[(position & PAGE_MASK) * 4]++; // count - page[(position & PAGE_MASK) * 4 + 1] = Math.min(page[(position & PAGE_MASK) * 4 + 1], value); // min - page[(position & PAGE_MASK) * 4 + 2] = Math.max(page[(position & PAGE_MASK) * 4 + 2], value); // min - page[(position & PAGE_MASK) * 4 + 3] += value; // sum - } - - static ResultRow toResultRow(List pages, int position) { - double[] page = pages.get(position >>> PAGE_SHIFT); - double count = page[(position & PAGE_MASK) * 4]; - double min = page[(position & PAGE_MASK) * 4 + 1]; - double max = page[(position & PAGE_MASK) * 4 + 2]; - double sum = page[(position & PAGE_MASK) * 4 + 3]; - return new ResultRow(min, sum / count, max); + private static int hash(byte[] bytes, int limit) { + int hash = 1; + for (int i = 0; i < limit; i++) { + hash += hash * 129 + bytes[i]; } + return hash; } - private static class AggregationTask extends RecursiveTask> { + private static class AggregationTask extends RecursiveAction { private final Dictionary dictionary; private final List slices; @@ -297,70 +257,52 @@ public class CalculateAverage_richardstartin { this.max = max; } - private void computeSlice(ByteBuffer slice, List pages) { - for (int offset = 0; offset < slice.limit();) { - int nextSeparator = findIndexOf(slice, offset, DELIMITER); - ByteBuffer key = slice.slice(offset, nextSeparator - offset).order(ByteOrder.LITTLE_ENDIAN); - // find the global dictionary code to aggregate, - // making this code global allows easy merging - int dictId = dictionary.encode(key.hashCode(), key); - - offset = nextSeparator + 1; - int newLine = findIndexOf(slice, offset, NEW_LINE); - // parse the double - double d = parseTemperature(slice.slice(offset, newLine - offset)); - - Page.update(pages, dictId, d); - - offset = newLine + 1; - } - } - - private static void merge(List contribution, List aggregate) { - for (int i = aggregate.size(); i < contribution.size(); i++) { - aggregate.add(null); - } - for (int i = 0; i < contribution.size(); i++) { - if (aggregate.get(i) == null) { - aggregate.set(i, contribution.get(i)); - } - else if (contribution.get(i) != null) { - double[] to = aggregate.get(i); - double[] from = contribution.get(i); - // todo won't vectorise - consider separating aggregates into distinct regions and apply - // loop fission (if this shows up in the profile) - for (int j = 0; j < to.length; j += 4) { - to[j] += from[j]; - to[j + 1] = Math.min(to[j + 1], from[j + 1]); - to[j + 2] = Math.max(to[j + 2], from[j + 2]); - to[j + 3] += from[j + 3]; + private void computeSlice(int stripe) { + var slice = slices.get(stripe); + int end = slice.limit(); + byte[] tmp = new byte[128]; + for (int offset = 0; offset < end;) { + int delimiter = findIndexOf(slice, end, offset, DELIMITER); + int value = 0; + int sign = 1; + byte b; + int i = delimiter + 1; + while (i != end && (b = slice.get(i++)) != '\n') { + if (b != '.') { + if (b == '-') { + sign = -1; + } + else { + value = 10 * value + (b - '0'); + } } } + value *= sign; + int length = delimiter - offset; + slice.get(offset, tmp, 0, length); + dictionary.lookup(hash(tmp, length), tmp, length, slices.size()).update(stripe, value); + offset = i; } } @Override - protected List compute() { + protected void compute() { if (min == max) { - var pages = new ArrayList(); - var slice = slices.get(min); - computeSlice(slice, pages); - return pages; + computeSlice(min); } else { int mid = (min + max) / 2; var low = new AggregationTask(dictionary, slices, min, mid); var high = new AggregationTask(dictionary, slices, mid + 1, max); var fork = high.fork(); - var partial = low.compute(); - merge(fork.join(), partial); - return partial; + low.compute(); + fork.join(); } } } public static void main(String[] args) throws IOException { - int maxChunkSize = 250 << 20; // 250MiB + int maxChunkSize = 10 << 20; // 10MiB try (var raf = new RandomAccessFile(FILE, "r"); var channel = raf.getChannel()) { long size = channel.size(); @@ -394,12 +336,13 @@ public class CalculateAverage_richardstartin { } } - var fjp = new ForkJoinPool(Runtime.getRuntime().availableProcessors()); - Dictionary dictionary = new Dictionary(); - var aggregates = fjp.submit(new AggregationTask(dictionary, slices)).join(); - var map = new TreeMap(); - dictionary.forEach((key, index) -> map.put(key, Page.toResultRow(aggregates, index))); - System.out.println(map); + try (var fjp = new ForkJoinPool(Runtime.getRuntime().availableProcessors())) { + Dictionary dictionary = new Dictionary(); + fjp.submit(new AggregationTask(dictionary, slices)).join(); + var map = new TreeMap(); + dictionary.forEach(slot -> map.put(slot.toKey(), slot.toResultRow())); + System.out.println(map); + } } } } \ No newline at end of file