update richardstartin submission (#325)

This commit is contained in:
Richard Startin 2024-01-12 08:06:09 +00:00 committed by GitHub
parent 6181996678
commit 37a50cb2af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -26,11 +26,10 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask; import java.util.concurrent.RecursiveAction;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.IntSupplier; import java.util.function.Consumer;
public class CalculateAverage_richardstartin { public class CalculateAverage_richardstartin {
@ -46,58 +45,64 @@ public class CalculateAverage_richardstartin {
} }
} }
static String bufferToString(ByteBuffer slice) { record Slot(byte[] key, int[] aggregates) {
byte[] bytes = new byte[slice.limit()]; private static final int WIDTH = 8;
slice.get(0, bytes); private static int[] newAggregates(int stripes) {
return new String(bytes, StandardCharsets.UTF_8); var aggregates = new int[stripes * WIDTH];
} for (int i = 0; i < aggregates.length; i += WIDTH) {
aggregates[i] = Integer.MAX_VALUE;
static double parseTemperature(ByteBuffer slice) { aggregates[i + 1] = Integer.MIN_VALUE;
// credit: adapted from spullara's submission
int value = 0;
int negative = 1;
int i = 0;
while (i != slice.limit()) {
byte b = slice.get(i++);
switch (b) {
case '-':
negative = -1;
case '.':
break;
default:
value = 10 * value + (b - '0');
} }
return aggregates;
}
/**
 * Creates a slot for {@code key} whose aggregates array is sized for {@code stripes} stripes.
 *
 * <p>The key array is stored as passed; this constructor does not copy it.
 *
 * @param key the raw key bytes
 * @param stripes the number of stripes to allocate aggregates for
 */
Slot(byte[] key, int stripes) {
// delegate to the canonical constructor with an array built by newAggregates
this(key, newAggregates(stripes));
}
value *= negative;
return value / 10.0;
}
@FunctionalInterface void update(int stripe, int value) {
interface IndexedStringConsumer { int i = stripe * WIDTH;
void accept(String value, int index); aggregates[i] = Math.min(value, aggregates[i]);
aggregates[i + 1] = Math.max(value, aggregates[i + 1]);
aggregates[i + 2] += value;
aggregates[i + 3]++;
}
/**
 * Collapses every stripe of this slot into a single result row.
 *
 * <p>Each stripe occupies {@code WIDTH} consecutive ints, of which the first four are read here
 * in the order min, max, sum, count. The three values handed to {@code ResultRow} are the
 * overall minimum, the mean and the overall maximum, each scaled by 0.1.
 *
 * <p>The combined sum and count are accumulated in {@code long}s. The per-stripe values are
 * ints, but adding the sums of many stripes together can exceed {@code Integer.MAX_VALUE}, and
 * int addition wraps silently, which would corrupt the mean.
 *
 * @return the row built from the merged minimum, mean and maximum
 */
public ResultRow toResultRow() {
    int min = Integer.MAX_VALUE;
    int max = Integer.MIN_VALUE;
    long sum = 0;
    long count = 0;
    for (int i = 0; i < aggregates.length; i += WIDTH) {
        min = Math.min(min, aggregates[i]);
        max = Math.max(max, aggregates[i + 1]);
        sum += aggregates[i + 2];
        count += aggregates[i + 3];
    }
    // 0.1 * sum promotes to double before the division, so this is not integer division
    return new ResultRow(min * 0.1, 0.1 * sum / count, max * 0.1);
}
/**
 * Decodes the stored key bytes as UTF-8 text.
 *
 * @return the key as a {@code String}
 */
public String toKey() {
    final byte[] raw = key;
    return new String(raw, StandardCharsets.UTF_8);
}
} }
/** Maps text to an integer encoding. Adapted from async-profiler. */ /** Maps text to an integer encoding. Adapted from async-profiler. */
public static class Dictionary { public static class Dictionary {
private static final int ROW_BITS = 12; private static final int ROW_BITS = 11;
private static final int ROWS = (1 << ROW_BITS); private static final int ROWS = (1 << ROW_BITS);
private static final int CELLS = 3; private static final int TABLE_CAPACITY = ROWS;
private static final int TABLE_CAPACITY = (ROWS * CELLS);
private final Table table = new Table(nextBaseIndex()); private final Table table = new Table(this, nextBaseIndex());
private static final AtomicIntegerFieldUpdater<Dictionary> BASE_INDEX_UPDATER = AtomicIntegerFieldUpdater.newUpdater(Dictionary.class, "baseIndex"); private static final AtomicIntegerFieldUpdater<Dictionary> BASE_INDEX_UPDATER = AtomicIntegerFieldUpdater.newUpdater(Dictionary.class, "baseIndex");
volatile int baseIndex; volatile int baseIndex;
private void forEach(Table table, IndexedStringConsumer consumer) { private void forEach(Table table, Consumer<Slot> consumer) {
for (int i = 0; i < ROWS; i++) { for (var row : table.rows) {
Row row = table.rows[i]; var slot = row.slot;
for (int j = 0; j < CELLS; j++) { if (slot != null) {
var slice = row.keys.get(j); consumer.accept(slot);
if (slice != null) {
consumer.accept(bufferToString(slice), table.index(i, j));
}
} }
if (row.next != null) { if (row.next != null) {
forEach(row.next, consumer); forEach(row.next, consumer);
@ -105,33 +110,32 @@ public class CalculateAverage_richardstartin {
} }
} }
public void forEach(IndexedStringConsumer consumer) { public void forEach(Consumer<Slot> consumer) {
forEach(this.table, consumer); forEach(this.table, consumer);
} }
public int encode(int hash, ByteBuffer slice) { public Slot lookup(int hash, byte[] key, int length, int stripes) {
Table table = this.table; Table table = this.table;
while (true) { while (true) {
int rowIndex = Math.abs(hash) % ROWS; int rowIndex = Math.abs(hash) % ROWS;
Row row = table.rows[rowIndex]; Row row = table.rows[rowIndex];
for (int c = 0; c < CELLS; c++) { var storedSlot = row.slot;
ByteBuffer storedKey = row.keys.get(c); if (storedSlot == null) {
if (storedKey == null) { Slot slot = new Slot(Arrays.copyOf(key, length), stripes);
if (row.keys.compareAndSet(c, null, slice)) { if (row.compareAndSet(null, slot)) {
return table.index(rowIndex, c); return slot;
}
else {
storedKey = row.keys.get(c);
if (slice.equals(storedKey)) {
return table.index(rowIndex, c);
}
}
} }
else if (slice.equals(storedKey)) { else {
return table.index(rowIndex, c); storedSlot = row.slot;
if (Arrays.equals(key, 0, length, storedSlot.key, 0, storedSlot.key.length)) {
return storedSlot;
}
} }
} }
table = row.getOrCreateNextTable(this::nextBaseIndex); else if (Arrays.equals(key, 0, length, storedSlot.key, 0, storedSlot.key.length)) {
return storedSlot;
}
table = row.getOrCreateNextTable();
hash = Integer.rotateRight(hash, ROW_BITS); hash = Integer.rotateRight(hash, ROW_BITS);
} }
} }
@ -143,13 +147,19 @@ public class CalculateAverage_richardstartin {
private static final class Row { private static final class Row {
private static final AtomicReferenceFieldUpdater<Row, Table> NEXT_TABLE_UPDATER = AtomicReferenceFieldUpdater.newUpdater(Row.class, Table.class, "next"); private static final AtomicReferenceFieldUpdater<Row, Table> NEXT_TABLE_UPDATER = AtomicReferenceFieldUpdater.newUpdater(Row.class, Table.class, "next");
private final AtomicReferenceArray<ByteBuffer> keys = new AtomicReferenceArray<>(CELLS); private static final AtomicReferenceFieldUpdater<Row, Slot> SLOT_UPDATER = AtomicReferenceFieldUpdater.newUpdater(Row.class, Slot.class, "slot");
private volatile Slot slot = null;
private final Dictionary dictionary;
volatile Table next; volatile Table next;
public Table getOrCreateNextTable(IntSupplier baseIndexSupplier) { private Row(Dictionary dictionary) {
this.dictionary = dictionary;
}
public Table getOrCreateNextTable() {
Table next = this.next; Table next = this.next;
if (next == null) { if (next == null) {
Table newTable = new Table(baseIndexSupplier.getAsInt()); Table newTable = new Table(dictionary, dictionary.nextBaseIndex());
if (NEXT_TABLE_UPDATER.compareAndSet(this, null, newTable)) { if (NEXT_TABLE_UPDATER.compareAndSet(this, null, newTable)) {
next = newTable; next = newTable;
} }
@ -159,6 +169,10 @@ public class CalculateAverage_richardstartin {
} }
return next; return next;
} }
/**
 * Atomically installs {@code newSlot} as this row's slot if the row currently holds
 * {@code expected}.
 *
 * @param expected the slot the row is expected to hold (may be {@code null} for an empty row)
 * @param newSlot the slot to install
 * @return {@code true} if the slot was replaced, {@code false} if the row held something else
 */
public boolean compareAndSet(Slot expected, Slot newSlot) {
return SLOT_UPDATER.compareAndSet(this, expected, newSlot);
}
} }
private static final class Table { private static final class Table {
@ -166,14 +180,10 @@ public class CalculateAverage_richardstartin {
final Row[] rows; final Row[] rows;
final int baseIndex; final int baseIndex;
private Table(int baseIndex) { private Table(Dictionary dictionary, int baseIndex) {
this.baseIndex = baseIndex; this.baseIndex = baseIndex;
this.rows = new Row[ROWS]; this.rows = new Row[ROWS];
Arrays.setAll(rows, i -> new Row()); Arrays.setAll(rows, i -> new Row(dictionary));
}
int index(int row, int col) {
return baseIndex + (col << ROW_BITS) + row;
} }
} }
} }
@ -182,23 +192,11 @@ public class CalculateAverage_richardstartin {
return 0x101010101010101L * repeat; return 0x101010101010101L * repeat;
} }
/**
 * Overload for a {@code char} delimiter.
 *
 * <p>Only the low 8 bits of the character are kept; the masked value is a {@code long}, so the
 * call below resolves to the {@code long} overload of {@code compilePattern}.
 *
 * @param delimiter the delimiter character; bits above the low byte are discarded
 * @return whatever the {@code long} overload returns for the masked value
 */
private static long compilePattern(char delimiter) {
return compilePattern(delimiter & 0xFFL);
}
private static long compilePattern(byte delimiter) { private static long compilePattern(byte delimiter) {
return compilePattern(delimiter & 0xFFL); return compilePattern(delimiter & 0xFFL);
} }
private static final long NEW_LINE = compilePattern((byte) '\n'); private static final long DELIMITER = compilePattern((byte) ';');
private static final long DELIMITER = compilePattern(';');
/**
 * Finds the lowest-order byte of {@code word} that equals the corresponding byte of
 * {@code pattern}, examining all eight bytes at once.
 *
 * @param word eight bytes packed into a long
 * @param pattern the bytes to compare against, lane by lane
 *        (NOTE(review): presumably one byte repeated in all eight lanes, as built by
 *        compilePattern - confirm against callers)
 * @return the index of the first matching byte, counting from the least-significant byte, or
 *         8 when no byte matches
 */
private static int firstInstance(long word, long pattern) {
// a byte that matches the pattern becomes 0x00 after the XOR
long input = word ^ pattern;
// adding 0x7F to the low 7 bits of each byte sets bit 7 unless those 7 bits were all zero;
// the per-byte sum is at most 0xFE, so nothing carries into the neighbouring byte
long tmp = (input & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL;
// after OR-ing in the original bit 7 and forcing the low 7 bits on, the complement leaves
// exactly bit 7 set in every byte that was zero, and nothing else
tmp = ~(tmp | input | 0x7F7F7F7F7F7F7F7FL);
// bit position divided by 8 is the byte index; numberOfTrailingZeros(0) is 64, which yields 8
return Long.numberOfTrailingZeros(tmp) >>> 3;
}
private static int findLastNewLine(ByteBuffer buffer) { private static int findLastNewLine(ByteBuffer buffer) {
return findLastNewLine(buffer, buffer.limit() - 1); return findLastNewLine(buffer, buffer.limit() - 1);
@ -213,16 +211,19 @@ public class CalculateAverage_richardstartin {
return 0; return 0;
} }
private static int findIndexOf(ByteBuffer buffer, int offset, long pattern) { private static int findIndexOf(ByteBuffer buffer, int limit, int offset, long pattern) {
int i = offset; int i = offset;
for (; i + Long.BYTES < buffer.limit(); i += Long.BYTES) { for (; i < limit - Long.BYTES + 1; i += Long.BYTES) {
int index = firstInstance(buffer.getLong(i), pattern); long word = buffer.getLong(i);
if (index != Long.BYTES) { long input = word ^ pattern;
return i + index; long tmp = (input & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL;
tmp |= input | 0x7F7F7F7F7F7F7F7FL;
if (tmp != -1L) {
return i + (Long.numberOfTrailingZeros(~tmp) >>> 3);
} }
} }
byte b = (byte) (pattern & 0xFF); byte b = (byte) (pattern & 0xFF);
for (; i < buffer.limit(); i++) { for (; i < limit; i++) {
if (buffer.get(i) == b) { if (buffer.get(i) == b) {
return i; return i;
} }
@ -230,56 +231,15 @@ public class CalculateAverage_richardstartin {
return buffer.limit(); return buffer.limit();
} }
static class Page { private static int hash(byte[] bytes, int limit) {
int hash = 1;
static final int PAGE_SIZE = 1024; for (int i = 0; i < limit; i++) {
static final int PAGE_SHIFT = Integer.numberOfTrailingZeros(PAGE_SIZE); hash += hash * 129 + bytes[i];
static final int PAGE_MASK = PAGE_SIZE - 1;
/**
 * Template for a fresh page: four doubles per entry, with the second set to positive infinity
 * and the third to negative infinity so the first min/max comparison always takes the new
 * value. The first and fourth stay at zero.
 */
private static final double[] PAGE_PROTOTYPE = new double[PAGE_SIZE * 4];

static {
    for (int entry = 0; entry < PAGE_SIZE; entry++) {
        final int base = entry * 4;
        PAGE_PROTOTYPE[base + 1] = Double.POSITIVE_INFINITY;
        PAGE_PROTOTYPE[base + 2] = Double.NEGATIVE_INFINITY;
    }
}
/**
 * Returns a new page initialised as a full-length copy of the prototype.
 *
 * @return an independent array with the same contents as {@code PAGE_PROTOTYPE}
 */
private static double[] newPage() {
    return PAGE_PROTOTYPE.clone();
}
/**
 * Folds one observation into the aggregates stored for {@code position}.
 *
 * <p>Pages are allocated lazily: the list is padded with {@code null} until the page owning
 * {@code position} is addressable, and that page is created with {@code newPage()} on first
 * use. Each position owns four consecutive doubles: count, min, max, sum.
 *
 * @param pages the page list to update; grown and filled in place
 * @param position the index of the entry to update
 * @param value the observation to fold in
 */
static void update(List<double[]> pages, int position, double value) {
    final int pageIndex = position >>> PAGE_SHIFT;
    // pad the list so that pageIndex is a valid index
    while (pages.size() <= pageIndex) {
        pages.add(null);
    }
    double[] page = pages.get(pageIndex);
    if (page == null) {
        page = newPage();
        pages.set(pageIndex, page);
    }
    // offset of this entry's four doubles within the page
    final int base = (position & PAGE_MASK) * 4;
    page[base]++; // count
    page[base + 1] = Math.min(page[base + 1], value); // min
    page[base + 2] = Math.max(page[base + 2], value); // max
    page[base + 3] += value; // sum
}
static ResultRow toResultRow(List<double[]> pages, int position) {
double[] page = pages.get(position >>> PAGE_SHIFT);
double count = page[(position & PAGE_MASK) * 4];
double min = page[(position & PAGE_MASK) * 4 + 1];
double max = page[(position & PAGE_MASK) * 4 + 2];
double sum = page[(position & PAGE_MASK) * 4 + 3];
return new ResultRow(min, sum / count, max);
} }
return hash;
} }
private static class AggregationTask extends RecursiveTask<List<double[]>> { private static class AggregationTask extends RecursiveAction {
private final Dictionary dictionary; private final Dictionary dictionary;
private final List<ByteBuffer> slices; private final List<ByteBuffer> slices;
@ -297,70 +257,52 @@ public class CalculateAverage_richardstartin {
this.max = max; this.max = max;
} }
private void computeSlice(ByteBuffer slice, List<double[]> pages) { private void computeSlice(int stripe) {
for (int offset = 0; offset < slice.limit();) { var slice = slices.get(stripe);
int nextSeparator = findIndexOf(slice, offset, DELIMITER); int end = slice.limit();
ByteBuffer key = slice.slice(offset, nextSeparator - offset).order(ByteOrder.LITTLE_ENDIAN); byte[] tmp = new byte[128];
// find the global dictionary code to aggregate, for (int offset = 0; offset < end;) {
// making this code global allows easy merging int delimiter = findIndexOf(slice, end, offset, DELIMITER);
int dictId = dictionary.encode(key.hashCode(), key); int value = 0;
int sign = 1;
offset = nextSeparator + 1; byte b;
int newLine = findIndexOf(slice, offset, NEW_LINE); int i = delimiter + 1;
// parse the double while (i != end && (b = slice.get(i++)) != '\n') {
double d = parseTemperature(slice.slice(offset, newLine - offset)); if (b != '.') {
if (b == '-') {
Page.update(pages, dictId, d); sign = -1;
}
offset = newLine + 1; else {
} value = 10 * value + (b - '0');
} }
private static void merge(List<double[]> contribution, List<double[]> aggregate) {
for (int i = aggregate.size(); i < contribution.size(); i++) {
aggregate.add(null);
}
for (int i = 0; i < contribution.size(); i++) {
if (aggregate.get(i) == null) {
aggregate.set(i, contribution.get(i));
}
else if (contribution.get(i) != null) {
double[] to = aggregate.get(i);
double[] from = contribution.get(i);
// todo won't vectorise - consider separating aggregates into distinct regions and apply
// loop fission (if this shows up in the profile)
for (int j = 0; j < to.length; j += 4) {
to[j] += from[j];
to[j + 1] = Math.min(to[j + 1], from[j + 1]);
to[j + 2] = Math.max(to[j + 2], from[j + 2]);
to[j + 3] += from[j + 3];
} }
} }
value *= sign;
int length = delimiter - offset;
slice.get(offset, tmp, 0, length);
dictionary.lookup(hash(tmp, length), tmp, length, slices.size()).update(stripe, value);
offset = i;
} }
} }
@Override @Override
protected List<double[]> compute() { protected void compute() {
if (min == max) { if (min == max) {
var pages = new ArrayList<double[]>(); computeSlice(min);
var slice = slices.get(min);
computeSlice(slice, pages);
return pages;
} }
else { else {
int mid = (min + max) / 2; int mid = (min + max) / 2;
var low = new AggregationTask(dictionary, slices, min, mid); var low = new AggregationTask(dictionary, slices, min, mid);
var high = new AggregationTask(dictionary, slices, mid + 1, max); var high = new AggregationTask(dictionary, slices, mid + 1, max);
var fork = high.fork(); var fork = high.fork();
var partial = low.compute(); low.compute();
merge(fork.join(), partial); fork.join();
return partial;
} }
} }
} }
public static void main(String[] args) throws IOException { public static void main(String[] args) throws IOException {
int maxChunkSize = 250 << 20; // 250MiB int maxChunkSize = 10 << 20; // 10MiB
try (var raf = new RandomAccessFile(FILE, "r"); try (var raf = new RandomAccessFile(FILE, "r");
var channel = raf.getChannel()) { var channel = raf.getChannel()) {
long size = channel.size(); long size = channel.size();
@ -394,12 +336,13 @@ public class CalculateAverage_richardstartin {
} }
} }
var fjp = new ForkJoinPool(Runtime.getRuntime().availableProcessors()); try (var fjp = new ForkJoinPool(Runtime.getRuntime().availableProcessors())) {
Dictionary dictionary = new Dictionary(); Dictionary dictionary = new Dictionary();
var aggregates = fjp.submit(new AggregationTask(dictionary, slices)).join(); fjp.submit(new AggregationTask(dictionary, slices)).join();
var map = new TreeMap<String, ResultRow>(); var map = new TreeMap<String, ResultRow>();
dictionary.forEach((key, index) -> map.put(key, Page.toResultRow(aggregates, index))); dictionary.forEach(slot -> map.put(slot.toKey(), slot.toResultRow()));
System.out.println(map); System.out.println(map);
}
} }
} }
} }