diff --git a/calculate_average_artsiomkorzun.sh b/calculate_average_artsiomkorzun.sh index d9c1828..977b6e3 100755 --- a/calculate_average_artsiomkorzun.sh +++ b/calculate_average_artsiomkorzun.sh @@ -17,9 +17,9 @@ if [ -f target/CalculateAverage_artsiomkorzun_image ]; then echo "Picking up existing native image 'target/CalculateAverage_artsiomkorzun_image', delete the file to select JVM mode." 1>&2 - target/CalculateAverage_artsiomkorzun_image + target/CalculateAverage_artsiomkorzun_image -XX:MaxDirectMemorySize=4294967296 else - JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation" + JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation -XX:MaxDirectMemorySize=4294967296" echo "Chosing to run the app in JVM mode as no native image was found, use prepare_artsiomkorzun.sh to generate." 1>&2 java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_artsiomkorzun fi \ No newline at end of file diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java index c3c39ab..1373154 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java @@ -17,12 +17,12 @@ package dev.morling.onebrc; import sun.misc.Unsafe; -import java.lang.foreign.Arena; -import java.lang.foreign.MemorySegment; import java.lang.reflect.Field; +import java.nio.Buffer; +import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.file.Files; import java.nio.file.Path; -import java.nio.file.StandardOpenOption; import java.util.Map; import java.util.TreeMap; import java.util.concurrent.atomic.AtomicInteger; @@ -31,19 +31,21 @@ import java.util.concurrent.atomic.AtomicReference; public class CalculateAverage_artsiomkorzun { private static final Path FILE = Path.of("./measurements.txt"); - private static final long SEGMENT_SIZE = 32 * 1024 * 1024; - private static final long SEGMENT_OVERLAP = 1024; + private static final int SEGMENT_SIZE = 4 * 1024 * 1024; + private static final int SEGMENT_OVERLAP = 128; private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL; private static final long DOT_BITS = 0x10101000; private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); private static final Unsafe UNSAFE; + private static final long ADDRESS_OFFSET; static { try { Field unsafe = Unsafe.class.getDeclaredField("theUnsafe"); unsafe.setAccessible(true); UNSAFE = (Unsafe) unsafe.get(Unsafe.class); + ADDRESS_OFFSET = UNSAFE.objectFieldOffset(Buffer.class.getDeclaredField("address")); } catch (Throwable e) { throw new RuntimeException(e); @@ -62,9 +64,7 @@ public class CalculateAverage_artsiomkorzun { } private static void execute() throws Exception { - MemorySegment fileMemory = map(FILE); - long fileAddress = fileMemory.address(); - long fileSize = fileMemory.byteSize(); + long fileSize = Files.size(FILE); int segmentCount = (int) ((fileSize + SEGMENT_SIZE - 1) / SEGMENT_SIZE); AtomicInteger counter = new AtomicInteger(); @@ -74,7 +74,7 @@ public class CalculateAverage_artsiomkorzun { Aggregator[] aggregators = new Aggregator[parallelism]; for (int i = 0; i < aggregators.length; i++) { - aggregators[i] = new Aggregator(counter, result, fileAddress, fileSize, segmentCount); + aggregators[i] = new Aggregator(counter, result, segmentCount); aggregators[i].start(); } @@ -86,14 +86,16 @@ public class CalculateAverage_artsiomkorzun { System.out.println(text(aggregates)); } - private static MemorySegment map(Path file) { - try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) { - long size = channel.size(); - return channel.map(FileChannel.MapMode.READ_ONLY, 0, size, Arena.global()); - } - catch (Throwable e) { - throw new RuntimeException(e); - } + private static long address(ByteBuffer buffer) { + return UNSAFE.getLong(buffer, ADDRESS_OFFSET); + } + + private static ByteBuffer allocate(int size) { + ByteBuffer buffer = ByteBuffer.allocateDirect(size + 4096); + long address = address(buffer); + long aligned = (address + 4095) & (~4095); + int padding = (int) (aligned - address); + return buffer.position(padding).limit(padding + size).slice(); } private static long word(long address) { @@ -139,13 +141,8 @@ public class CalculateAverage_artsiomkorzun { private static final int ENTRIES = 64 * 1024; private static final int SIZE = 128 * ENTRIES; - private final long pointer; - - public Aggregates() { - long address = UNSAFE.allocateMemory(SIZE + 8096); - pointer = (address + 4095) & (~4095); - UNSAFE.setMemory(pointer, SIZE, (byte) 0); - } + private final ByteBuffer buffer = allocate(SIZE); + private final long pointer = address(buffer); public long find(long word, int hash) { long address = pointer + offset(hash); @@ -206,14 +203,8 @@ public class CalculateAverage_artsiomkorzun { for (int offset = offset(hash);; offset = next(offset)) { long address = pointer + offset; - int len = UNSAFE.getInt(address); - if (len == 0) { - UNSAFE.copyMemory(rightAddress, address, 24 + length); - break; - } - - if (len == length && equal(address + 24, rightAddress + 24, length)) { + if (equal(address + 24, rightAddress + 24, length)) { long sum = UNSAFE.getLong(address + 8) + UNSAFE.getLong(rightAddress + 8); int cnt = UNSAFE.getInt(address + 16) + UNSAFE.getInt(rightAddress + 16); short min = (short) Math.min(UNSAFE.getShort(address + 20), UNSAFE.getShort(rightAddress + 20)); @@ -225,6 +216,13 @@ public class CalculateAverage_artsiomkorzun { UNSAFE.putShort(address + 22, max); break; } + + int len = UNSAFE.getInt(address); + + if (len == 0) { + UNSAFE.copyMemory(rightAddress, address, length + 24); + break; + } } } } @@ -237,8 +235,8 @@ public class CalculateAverage_artsiomkorzun { int length = UNSAFE.getInt(address); if (length != 0) { - byte[] array = new byte[length]; - UNSAFE.copyMemory(null, address + 24, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); + byte[] array = new byte[length - 1]; + UNSAFE.copyMemory(null, address + 24, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, array.length); String key = new String(array); long sum = UNSAFE.getLong(address + 8); @@ -271,7 +269,7 @@ public class CalculateAverage_artsiomkorzun { } private static boolean equal(long leftAddress, long leftWord, long rightAddress, int length) { - while (length >= 8) { + while (length > 8) { long left = UNSAFE.getLong(leftAddress); long right = UNSAFE.getLong(rightAddress); @@ -309,35 +307,39 @@ public class CalculateAverage_artsiomkorzun { private final AtomicInteger counter; private final AtomicReference result; - private final long fileAddress; - private final long fileSize; - private final int segmentCount; + private final int segments; - public Aggregator(AtomicInteger counter, AtomicReference result, - long fileAddress, long fileSize, int segmentCount) { + public Aggregator(AtomicInteger counter, AtomicReference result, int segments) { super("aggregator"); this.counter = counter; this.result = result; - this.fileAddress = fileAddress; - this.fileSize = fileSize; - this.segmentCount = segmentCount; + this.segments = segments; } @Override public void run() { Aggregates aggregates = new Aggregates(); + ByteBuffer buffer = allocate(SEGMENT_SIZE + SEGMENT_OVERLAP); - for (int segment; (segment = counter.getAndIncrement()) < segmentCount;) { - long position = SEGMENT_SIZE * segment; - long size = Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, fileSize - position); - long address = fileAddress + position; - long limit = address + Math.min(SEGMENT_SIZE, size - 1); + try (FileChannel channel = FileChannel.open(FILE)) { + for (int segment; (segment = counter.getAndIncrement()) < segments;) { + buffer.clear(); - if (segment > 0) { - address = next(address); + long position = (long) SEGMENT_SIZE * segment; + int size = channel.read(buffer, position); + + long address = address(buffer); + long limit = address + Math.min(SEGMENT_SIZE, size - 1); + + if (segment > 0) { + address = next(address); + } + + aggregate(aggregates, address, limit); } - - aggregate(aggregates, address, limit); + } + catch (Throwable e) { + throw new RuntimeException(e); } while (!result.compareAndSet(null, aggregates)) { @@ -406,7 +408,7 @@ public class CalculateAverage_artsiomkorzun { ptr = aggregates.put(position, word, length, hash); } - position = update(ptr, position + length + 1); + position = update(ptr, position + length); } } @@ -431,12 +433,12 @@ public class CalculateAverage_artsiomkorzun { } private static long mask(long word, long separator) { - long mask = ((separator - 1) ^ separator) >>> 8; + long mask = separator ^ (separator - 1); return word & mask; } private static int length(long separator) { - return Long.numberOfTrailingZeros(separator) >>> 3; + return (Long.numberOfTrailingZeros(separator) >>> 3) + 1; } private static long next(long position) {