From 8602a355048a8d34d95611e7f80c00e1ea9b853a Mon Sep 17 00:00:00 2001 From: Artsiom Korzun <72259616+artsiomkorzun@users.noreply.github.com> Date: Thu, 11 Jan 2024 09:00:24 +0100 Subject: [PATCH] improved artsiomkorzun solution (#176) improved artsiomkorzun solution improved artsiomkorzun solution Co-authored-by: Artsiom Korzun --- calculate_average_artsiomkorzun.sh | 2 +- .../CalculateAverage_artsiomkorzun.java | 452 ++++++++++-------- 2 files changed, 254 insertions(+), 200 deletions(-) diff --git a/calculate_average_artsiomkorzun.sh b/calculate_average_artsiomkorzun.sh index 805330e..eaf050c 100755 --- a/calculate_average_artsiomkorzun.sh +++ b/calculate_average_artsiomkorzun.sh @@ -15,5 +15,5 @@ # limitations under the License. # -JAVA_OPTS="-XX:+UseParallelGC" +JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_artsiomkorzun diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java index 516a6ab..c9b7144 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java @@ -15,28 +15,46 @@ */ package dev.morling.onebrc; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.MappedByteBuffer; +import sun.misc.Unsafe; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.reflect.Field; +import java.nio.ByteOrder; import java.nio.channels.FileChannel; -import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.Arrays; -import java.util.Comparator; +import java.util.Map; +import java.util.TreeMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; public class CalculateAverage_artsiomkorzun { private static final Path FILE = Path.of("./measurements.txt"); - private static final long FILE_SIZE = size(FILE); + private static final MemorySegment MAPPED_FILE = map(FILE); private static final int PARALLELISM = Runtime.getRuntime().availableProcessors(); private static final int SEGMENT_SIZE = 16 * 1024 * 1024; - private static final int SEGMENT_COUNT = (int) ((FILE_SIZE + SEGMENT_SIZE - 1) / SEGMENT_SIZE); + private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE); private static final int SEGMENT_OVERLAP = 1024; + private static final long COMMA_PATTERN = pattern(';'); + private static final long DOT_BITS = 0x10101000; + private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); + + private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder(); + private static final Unsafe UNSAFE; + + static { + try { + Field unsafe = Unsafe.class.getDeclaredField("theUnsafe"); + unsafe.setAccessible(true); + UNSAFE = (Unsafe) unsafe.get(Unsafe.class); + } + catch (Throwable e) { + throw new RuntimeException(e); + } + } public static void main(String[] args) throws Exception { // for (int i = 0; i < 10; i++) { @@ -63,196 +81,231 @@ public class CalculateAverage_artsiomkorzun { aggregators[i].join(); } - Aggregates aggregates = result.get(); - aggregates.sort(); - - print(aggregates); + Map aggregates = result.get().aggregate(); + System.out.println(text(aggregates)); } - private static void print(Aggregates aggregates) { - StringBuilder builder = new StringBuilder(aggregates.size() * 15 + 32); - builder.append("{"); - aggregates.visit(aggregate -> { - if (builder.length() > 1) { - builder.append(", "); - } - - builder.append(aggregate); - }); - builder.append("}"); - System.out.println(builder); - } - - private static long size(Path file) { - try { - return Files.size(file); + private static MemorySegment map(Path file) { + try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) { + long size = channel.size(); + return channel.map(FileChannel.MapMode.READ_ONLY, 0, size, Arena.global()); } - catch (IOException e) { + catch (Throwable e) { throw new RuntimeException(e); } } - private static class Row { - final byte[] station = new byte[256]; - int length; - int hash; - int temperature; - - @Override - public String toString() { - return new String(station, 0, length) + ":" + temperature; - } + private static long pattern(char c) { + long b = c & 0xFFL; + return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56); } - private static class Aggregate implements Comparable { - final byte[] station; - final int hash; - int min; - int max; - long sum; - int count; + private static long getLongBigEndian(long address) { + long value = UNSAFE.getLong(address); - public Aggregate(Row row) { - this.station = Arrays.copyOf(row.station, row.length); - this.hash = row.hash; - this.min = row.temperature; - this.max = row.temperature; - this.sum = row.temperature; - this.count = 1; + if (BYTE_ORDER == ByteOrder.LITTLE_ENDIAN) { + value = Long.reverseBytes(value); } - public void add(Row row) { - min = Math.min(min, row.temperature); - max = Math.max(max, row.temperature); - sum += row.temperature; - count++; + return value; + } + + private static long getLongLittleEndian(long address) { + long value = UNSAFE.getLong(address); + + if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) { + value = Long.reverseBytes(value); } - public void merge(Aggregate right) { - min = Math.min(min, right.min); - max = Math.max(max, right.max); - sum += right.sum; - count += right.count; - } + return value; + } - @Override - public int compareTo(Aggregate that) { - byte[] lhs = this.station; - byte[] rhs = that.station; - int limit = Math.min(lhs.length, rhs.length); + private static String text(Map aggregates) { + StringBuilder text = new StringBuilder(aggregates.size() * 32 + 2); + text.append('{'); - for (int offset = 0; offset < limit; offset++) { - int left = lhs[offset]; - int right = rhs[offset]; - - if (left != right) { - return (left & 0xFF) - (right & 0xFF); - } + for (Map.Entry entry : aggregates.entrySet()) { + if (text.length() > 1) { + text.append(", "); } - return lhs.length - rhs.length; + Aggregate aggregate = entry.getValue(); + text.append(entry.getKey()).append('=') + .append(round(aggregate.min)).append('/') + .append(round(1.0 * aggregate.sum / aggregate.cnt)).append('/') + .append(round(aggregate.max)); } - @Override - public String toString() { - return new String(station) + "=" + round(min) + "/" + round(1.0 * sum / count) + "/" + round(max); - } + text.append('}'); + return text.toString(); + } - private static double round(double v) { - return Math.round(v) / 10.0; - } + private static double round(double v) { + return Math.round(v) / 10.0; + } + + private static class Row { + long address; + int length; + int hash; + int value; + } + + private record Aggregate(int min, int max, long sum, int cnt) { } private static class Aggregates { - private static final int GROW_FACTOR = 4; - private static final float LOAD_FACTOR = 0.55f; + private static final int SIZE = 16 * 1024; + private final long pointer; - private Aggregate[] aggregates = new Aggregate[1024]; - private int limit = (int) (aggregates.length * LOAD_FACTOR); - private int size; + public Aggregates() { + int size = 32 * SIZE; + long address = UNSAFE.allocateMemory(size + 8096); + pointer = (address + 4095) & (~4095); + UNSAFE.setMemory(pointer, size, (byte) 0); - public int size() { - return size; - } - - public void visit(Consumer consumer) { - if (size > 0) { - for (Aggregate aggregate : aggregates) { - if (aggregate != null) { - consumer.accept(aggregate); - } - } + long word = pack(Short.MAX_VALUE, Short.MIN_VALUE, 0); + for (int i = 0; i < SIZE; i++) { + long entry = pointer + 32 * i; + UNSAFE.putLong(entry + 24, word); } } public void add(Row row) { - int index = row.hash & (aggregates.length - 1); + long index = index(row.hash); + long header = ((long) row.hash << 32) | (row.length); while (true) { - Aggregate aggregate = aggregates[index]; + long address = pointer + (index << 5); + long head = UNSAFE.getLong(address); + long ref = UNSAFE.getLong(address + 8); + boolean isHit = (head == 0) || (head == header && equal(ref, row.address, row.length)); - if (aggregate == null) { - aggregates[index] = new Aggregate(row); - if (++size >= limit) { - grow(); - } + if (isHit) { + long sum = UNSAFE.getLong(address + 16) + row.value; + long word = UNSAFE.getLong(address + 24); + int min = Math.min(min(word), row.value); + int max = Math.max(max(word), row.value); + int cnt = cnt(word) + 1; + + UNSAFE.putLong(address, header); + UNSAFE.putLong(address + 8, row.address); + UNSAFE.putLong(address + 16, sum); + UNSAFE.putLong(address + 24, pack(min, max, cnt)); break; } - if (row.hash == aggregate.hash && Arrays.equals(row.station, 0, row.length, aggregate.station, 0, aggregate.station.length)) { - aggregate.add(row); - break; - } - - index = (index + 1) & (aggregates.length - 1); + index = (index + 1) & (SIZE - 1); } } - public void merge(Aggregate right) { - int index = right.hash & (aggregates.length - 1); + public void merge(Aggregates rights) { + for (int rightIndex = 0; rightIndex < SIZE; rightIndex++) { + long rightAddress = rights.pointer + (rightIndex << 5); + long header = UNSAFE.getLong(rightAddress); + long reference = UNSAFE.getLong(rightAddress + 8); - while (true) { - Aggregate aggregate = aggregates[index]; + if (header == 0) { + continue; + } - if (aggregate == null) { - aggregates[index] = right; - if (++size >= limit) { - grow(); + int hash = (int) (header >>> 32); + int length = (int) (header); + long index = index(hash); + + while (true) { + long address = pointer + (index << 5); + long head = UNSAFE.getLong(address); + long ref = UNSAFE.getLong(address + 8); + boolean isHit = (head == 0) || (head == header && equal(ref, reference, length)); + + if (isHit) { + long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16); + long left = UNSAFE.getLong(address + 24); + long right = UNSAFE.getLong(rightAddress + 24); + int min = Math.min(min(left), min(right)); + int max = Math.max(max(left), max(right)); + int cnt = cnt(left) + cnt(right); + + UNSAFE.putLong(address, header); + UNSAFE.putLong(address + 8, reference); + UNSAFE.putLong(address + 16, sum); + UNSAFE.putLong(address + 24, pack(min, max, cnt)); + break; } - break; - } - if (right.hash == aggregate.hash && Arrays.equals(right.station, aggregate.station)) { - aggregate.merge(right); - break; + index = (index + 1) & (SIZE - 1); } - - index = (index + 1) & (aggregates.length - 1); } } - public Aggregates sort() { - Arrays.sort(aggregates, Comparator.nullsLast(Aggregate::compareTo)); - return this; + public Map aggregate() { + TreeMap set = new TreeMap<>(); + + for (int index = 0; index < SIZE; index++) { + long address = pointer + (index << 5); + long head = UNSAFE.getLong(address); + long ref = UNSAFE.getLong(address + 8); + + if (head == 0) { + continue; + } + + int length = (int) (head); + byte[] array = new byte[length]; + UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); + String key = new String(array); + + long sum = UNSAFE.getLong(address + 16); + long word = UNSAFE.getLong(address + 24); + + Aggregate aggregate = new Aggregate(min(word), max(word), sum, cnt(word)); + set.put(key, aggregate); + } + + return set; } - private void grow() { - Aggregate[] oldAggregates = aggregates; - aggregates = new Aggregate[oldAggregates.length * GROW_FACTOR]; - limit = (int) (aggregates.length * LOAD_FACTOR); + private static long pack(int min, int max, int cnt) { + return ((long) min << 48) | (((long) max & 0xFFFF) << 32) | cnt; + } - for (Aggregate aggregate : oldAggregates) { - if (aggregate != null) { - int index = aggregate.hash & (aggregates.length - 1); + private static int cnt(long word) { + return (int) word; + } - while (aggregates[index] != null) { - index = (index + 1) & (aggregates.length - 1); - } + private static int max(long word) { + return (short) (word >>> 32); + } - aggregates[index] = aggregate; + private static int min(long word) { + return (short) (word >>> 48); + } + + private static long index(int hash) { + return (hash ^ (hash >> 16)) & (SIZE - 1); + } + + private static boolean equal(long leftAddress, long rightAddress, int length) { + int index = 0; + + while (length > 8) { + long left = UNSAFE.getLong(leftAddress + index); + long right = UNSAFE.getLong(rightAddress + index); + + if (left != right) { + return false; } + + length -= 8; + index += 8; } + + int shift = 64 - (length << 3); + long left = getLongBigEndian(leftAddress + index) >>> shift; + long right = getLongBigEndian(rightAddress + index) >>> shift; + return (left == right); } } @@ -272,87 +325,88 @@ public class CalculateAverage_artsiomkorzun { Aggregates aggregates = new Aggregates(); Row row = new Row(); - try (FileChannel channel = FileChannel.open(FILE, StandardOpenOption.READ)) { - for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) { - aggregate(channel, segment, aggregates, row); - } - } - catch (Throwable e) { - throw new RuntimeException(e); + for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) { + aggregate(aggregates, row, segment); } while (!result.compareAndSet(null, aggregates)) { Aggregates rights = result.getAndSet(null); if (rights != null) { - aggregates = merge(aggregates, rights); + aggregates.merge(rights); } } } - private static void aggregate(FileChannel channel, int segment, Aggregates aggregates, Row row) throws Exception { + private static void aggregate(Aggregates aggregates, Row row, int segment) { long position = (long) SEGMENT_SIZE * segment; - int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, FILE_SIZE - position); - int limit = Math.min(SEGMENT_SIZE, size - 1); + int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, MAPPED_FILE.byteSize() - position); + long address = MAPPED_FILE.address() + position; + long limit = address + Math.min(SEGMENT_SIZE, size - 1); - MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, position, size); - - if (position > 0) { - next(buffer); + if (segment > 0) { + address = next(address); } - for (int offset = buffer.position(); offset <= limit;) { - offset = parse(buffer, row, offset); + while (address <= limit) { + // this parsing can produce seg fault at page boundaries + // e.g. file size is 4096 and the last entry is X=0.0, which is less than 8 bytes + // as a result a read will be split across pages, where one of them is not mapped + // but for some reason it works on my machine, leaving to investigate + address = parseKey(address, row); + address = parseValue(address, row); aggregates.add(row); } } - private static Aggregates merge(Aggregates lefts, Aggregates rights) { - if (rights.size() < lefts.size()) { - Aggregates temp = lefts; - lefts = rights; - rights = temp; - } - - rights.visit(lefts::merge); - return lefts; - } - - private static void next(ByteBuffer buffer) { - while (buffer.get() != '\n') { + private static long next(long address) { + while (UNSAFE.getByte(address++) != '\n') { // continue } + return address; } - private static int parse(ByteBuffer buffer, Row row, int offset) { - byte[] station = row.station; + // idea: royvanrijn + // explanation: https://richardstartin.github.io/posts/finding-bytes + private static long parseKey(long address, Row row) { int length = 0; - int hash = 0; + long hash = 0; + long word; - for (byte b; (b = buffer.get(offset++)) != ';';) { - station[length++] = b; - hash = 71 * hash + b; + while (true) { + word = getLongLittleEndian(address + length); + long match = word ^ COMMA_PATTERN; + long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L; + + if (mask == 0) { + hash = 71 * hash + word; + length += 8; + continue; + } + + int bit = Long.numberOfTrailingZeros(mask); + length += (bit >>> 3); + hash = 71 * hash + (word & (0x00FFFFFFFFFFFFFFL >>> (63 - bit))); + + row.address = address; + row.length = length; + row.hash = Long.hashCode(hash); + + return address + length + 1; } + } - row.length = length; - row.hash = hash; - - int sign = 1; - - if (buffer.get(offset) == '-') { - sign = -1; - offset++; - } - - int value = buffer.get(offset++) - '0'; - - if (buffer.get(offset) != '.') { - value = 10 * value + buffer.get(offset++) - '0'; - } - - value = 10 * value + buffer.get(offset + 1) - '0'; - row.temperature = value * sign; - return offset + 3; + // idea: merykitty + private static long parseValue(long address, Row row) { + long word = getLongLittleEndian(address); + long inverted = ~word; + int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS); + long signed = (inverted << 59) >> 63; + long mask = ~(signed & 0xFF); + long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L; + long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; + row.value = (int) ((abs ^ signed) - signed); + return address + (dot >> 3) + 3; } } }