From ac4805ee45d9d626d0ef93d3cbe6191b81d9e393 Mon Sep 17 00:00:00 2001 From: Artsiom Korzun <72259616+artsiomkorzun@users.noreply.github.com> Date: Sun, 21 Jan 2024 20:23:48 +0100 Subject: [PATCH] subprocess spawner (#542) --- calculate_average_artsiomkorzun.sh | 4 +- prepare_artsiomkorzun.sh | 2 +- .../CalculateAverage_artsiomkorzun.java | 115 +++++++++++------- 3 files changed, 76 insertions(+), 45 deletions(-) diff --git a/calculate_average_artsiomkorzun.sh b/calculate_average_artsiomkorzun.sh index 977b6e3..d9c1828 100755 --- a/calculate_average_artsiomkorzun.sh +++ b/calculate_average_artsiomkorzun.sh @@ -17,9 +17,9 @@ if [ -f target/CalculateAverage_artsiomkorzun_image ]; then echo "Picking up existing native image 'target/CalculateAverage_artsiomkorzun_image', delete the file to select JVM mode." 1>&2 - target/CalculateAverage_artsiomkorzun_image -XX:MaxDirectMemorySize=4294967296 + target/CalculateAverage_artsiomkorzun_image else - JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation -XX:MaxDirectMemorySize=4294967296" + JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation" echo "Chosing to run the app in JVM mode as no native image was found, use prepare_artsiomkorzun.sh to generate." 1>&2 java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_artsiomkorzun fi \ No newline at end of file diff --git a/prepare_artsiomkorzun.sh b/prepare_artsiomkorzun.sh index 9ae693a..9840486 100755 --- a/prepare_artsiomkorzun.sh +++ b/prepare_artsiomkorzun.sh @@ -16,7 +16,7 @@ # source "$HOME/.sdkman/bin/sdkman-init.sh" -sdk use java 21.0.1-graal 1>&2 +sdk use java 21.0.2-graal 1>&2 if [ ! -f target/CalculateAverage_artsiomkorzun_image ]; then NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native -R:MaxHeapSize=64m --enable-preview --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_artsiomkorzun" diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java index ca76d10..40b8db0 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java @@ -17,12 +17,14 @@ package dev.morling.onebrc; import sun.misc.Unsafe; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; import java.lang.reflect.Field; -import java.nio.Buffer; -import java.nio.ByteBuffer; import java.nio.channels.FileChannel; -import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Map; import java.util.TreeMap; import java.util.concurrent.atomic.AtomicInteger; @@ -31,21 +33,19 @@ import java.util.concurrent.atomic.AtomicReference; public class CalculateAverage_artsiomkorzun { private static final Path FILE = Path.of("./measurements.txt"); - private static final int SEGMENT_SIZE = 4 * 1024 * 1024; - private static final int SEGMENT_OVERLAP = 128; + private static final long SEGMENT_SIZE = 4 * 1024 * 1024; + private static final long SEGMENT_OVERLAP = 128; private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL; private static final long DOT_BITS = 0x10101000; private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); private static final Unsafe UNSAFE; - private static final long ADDRESS_OFFSET; static { try { Field unsafe = Unsafe.class.getDeclaredField("theUnsafe"); unsafe.setAccessible(true); UNSAFE = (Unsafe) unsafe.get(Unsafe.class); - ADDRESS_OFFSET = UNSAFE.objectFieldOffset(Buffer.class.getDeclaredField("address")); } catch (Throwable e) { throw new RuntimeException(e); @@ -60,11 +60,42 @@ public class CalculateAverage_artsiomkorzun { // System.err.println("Time: " + (end - start)); // } + if (isSpawn(args)) { + spawn(); + return; + } + execute(); } + private static boolean isSpawn(String[] args) { + for (String arg : args) { + if ("--worker".equals(arg)) { + return false; + } + } + + return true; + } + + private static void spawn() throws Exception { + ProcessHandle.Info info = ProcessHandle.current().info(); + ArrayList commands = new ArrayList<>(); + info.command().ifPresent(commands::add); + info.arguments().ifPresent(args -> commands.addAll(Arrays.asList(args))); + commands.add("--worker"); + + new ProcessBuilder() + .command(commands) + .start() + .getInputStream() + .transferTo(System.out); + } + private static void execute() throws Exception { - long fileSize = Files.size(FILE); + MemorySegment fileMemory = map(FILE); + long fileAddress = fileMemory.address(); + long fileSize = fileMemory.byteSize(); int segmentCount = (int) ((fileSize + SEGMENT_SIZE - 1) / SEGMENT_SIZE); AtomicInteger counter = new AtomicInteger(); @@ -74,7 +105,7 @@ public class CalculateAverage_artsiomkorzun { Aggregator[] aggregators = new Aggregator[parallelism]; for (int i = 0; i < aggregators.length; i++) { - aggregators[i] = new Aggregator(counter, result, segmentCount); + aggregators[i] = new Aggregator(counter, result, fileAddress, fileSize, segmentCount); aggregators[i].start(); } @@ -84,18 +115,17 @@ public class CalculateAverage_artsiomkorzun { Map aggregates = result.get().aggregate(); System.out.println(text(aggregates)); + System.out.close(); } - private static long address(ByteBuffer buffer) { - return UNSAFE.getLong(buffer, ADDRESS_OFFSET); - } - - private static ByteBuffer allocate(int size) { - ByteBuffer buffer = ByteBuffer.allocateDirect(size + 4096); - long address = address(buffer); - long aligned = (address + 4095) & (~4095); - int padding = (int) (aligned - address); - return buffer.position(padding).limit(padding + size).slice(); + private static MemorySegment map(Path file) { + try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) { + long size = channel.size(); + return channel.map(FileChannel.MapMode.READ_ONLY, 0, size, Arena.global()); + } + catch (Throwable e) { + throw new RuntimeException(e); + } } private static long word(long address) { @@ -142,8 +172,13 @@ public class CalculateAverage_artsiomkorzun { private static final int SIZE = 128 * ENTRIES; private static final int MASK = (ENTRIES - 1) << 7; - private final ByteBuffer buffer = allocate(SIZE); - private final long pointer = address(buffer); + private final long pointer; + + public Aggregates() { + long address = UNSAFE.allocateMemory(SIZE + 4096); + pointer = (address + 4095) & (~4095); + UNSAFE.setMemory(pointer, SIZE, (byte) 0); + } public long find(long word, int hash) { long address = pointer + offset(hash); @@ -308,39 +343,35 @@ public class CalculateAverage_artsiomkorzun { private final AtomicInteger counter; private final AtomicReference result; - private final int segments; + private final long fileAddress; + private final long fileSize; + private final int segmentCount; - public Aggregator(AtomicInteger counter, AtomicReference result, int segments) { + public Aggregator(AtomicInteger counter, AtomicReference result, + long fileAddress, long fileSize, int segmentCount) { super("aggregator"); this.counter = counter; this.result = result; - this.segments = segments; + this.fileAddress = fileAddress; + this.fileSize = fileSize; + this.segmentCount = segmentCount; } @Override public void run() { Aggregates aggregates = new Aggregates(); - ByteBuffer buffer = allocate(SEGMENT_SIZE + SEGMENT_OVERLAP); - try (FileChannel channel = FileChannel.open(FILE)) { - for (int segment; (segment = counter.getAndIncrement()) < segments;) { - buffer.clear(); + for (int segment; (segment = counter.getAndIncrement()) < segmentCount;) { + long position = SEGMENT_SIZE * segment; + long size = Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, fileSize - position); + long address = fileAddress + position; + long limit = address + Math.min(SEGMENT_SIZE, size - 1); - long position = (long) SEGMENT_SIZE * segment; - int size = channel.read(buffer, position); - - long address = address(buffer); - long limit = address + Math.min(SEGMENT_SIZE, size - 1); - - if (segment > 0) { - address = next(address); - } - - aggregate(aggregates, address, limit); + if (segment > 0) { + address = next(address); } - } - catch (Throwable e) { - throw new RuntimeException(e); + + aggregate(aggregates, address, limit); } while (!result.compareAndSet(null, aggregates)) {