From dac38bc97fb1411d1ae7a1a354fe9c7ae0c659d2 Mon Sep 17 00:00:00 2001 From: Andrew Sun Date: Fri, 12 Jan 2024 14:42:22 -0500 Subject: [PATCH] Optimizations to Andrew Sun's entry (#310) Squashed commit of the following: commit 44d3736de87834b41118d45831e59fc2b052117c Merge: fcf795f 3127962 Author: Andrew Sun Date: Thu Jan 11 20:01:13 2024 -0500 Merge branch 'gunnarmorling:main' into as-com commit fcf795fbabacbd91891d11d21450ee4b1c479dc5 Author: Andrew Sun Date: Wed Jan 10 21:14:01 2024 -0500 Optimizations to Andrew Sun's entry commit 4203924711bab5252ff3cbb50a90f4ce4e8e67c2 Merge: 9aed05a 085168a Author: Andrew Sun Date: Wed Jan 10 19:40:19 2024 -0500 Merge remote-tracking branch 'upstream/main' into as-com commit 9aed05a04bd27fe7323e66c347b1011c77da322c Merge: 3f8df58 c2d120f Author: Andrew Sun Date: Sun Jan 7 16:45:27 2024 -0500 Merge remote-tracking branch 'origin/as-com' into as-com # Conflicts: # calculate_average_asun.sh # src/main/java/dev/morling/onebrc/CalculateAverage_asun.java commit c2d120f0cb7f18c720a81a7f898102b310f9ecb9 Author: Andrew Sun Date: Sat Jan 6 00:45:47 2024 -0500 Add entry by Andrew Sun commit 3f8df5803bcc8f3e29ed8bfff3077eb0e8cdab15 Author: Andrew Sun Date: Sat Jan 6 00:45:47 2024 -0500 Add entry by Andrew Sun --- calculate_average_asun.sh | 2 +- .../morling/onebrc/CalculateAverage_asun.java | 233 +++++++++++++----- 2 files changed, 167 insertions(+), 68 deletions(-) diff --git a/calculate_average_asun.sh b/calculate_average_asun.sh index f3f8502..94f072f 100755 --- a/calculate_average_asun.sh +++ b/calculate_average_asun.sh @@ -15,5 +15,5 @@ # limitations under the License. # -JAVA_OPTS="--enable-preview --add-modules jdk.incubator.vector -XX:+UnlockExperimentalVMOptions -Xms500m -Xmx500m -XX:CompilationMode=high-only" +JAVA_OPTS="--enable-preview --add-modules jdk.incubator.vector -XX:+UnlockExperimentalVMOptions -XX:ActiveProcessorCount=8 -Xms500m -Xmx500m -XX:CompilationMode=high-only -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_asun diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java index 88a90ea..0f5b0da 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java @@ -16,12 +16,16 @@ package dev.morling.onebrc; import jdk.incubator.vector.*; +import sun.misc.Unsafe; import java.io.File; import java.io.IOException; -import java.io.RandomAccessFile; +import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.lang.reflect.Field; import java.nio.ByteOrder; import java.nio.channels.FileChannel; import java.nio.file.Files; @@ -31,8 +35,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.TreeMap; +import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ExecutionException; -import java.util.stream.Collectors; +import java.util.concurrent.atomic.AtomicLongArray; // based on spullara's submission @@ -53,26 +58,72 @@ public class CalculateAverage_asun { ASC = ByteVector.fromArray(BYTE_SPECIES, bytes, 0); } - public static void main(String[] args) throws IOException, ExecutionException, InterruptedException { - long start = System.currentTimeMillis(); - var filename = args.length == 0 ? FILE : args[0]; - var file = new File(filename); + private static final Unsafe UNSAFE; - List fileSegments = getFileSegments(file); - // System.out.println(System.currentTimeMillis() - start); - var resultsMap = fileSegments.stream().map(segment -> { + static { + try { + Field f = Unsafe.class.getDeclaredField("theUnsafe"); + f.setAccessible(true); + UNSAFE = (Unsafe) f.get(null); + } + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private static AtomicLongArray segmentQueue; + @SuppressWarnings("FieldMayBeFinal") + // @jdk.internal.vm.annotation.Contended + private static volatile int head = 0; + @SuppressWarnings("FieldMayBeFinal") + // @jdk.internal.vm.annotation.Contended + private static volatile int tail = 0; + @SuppressWarnings("FieldMayBeFinal") + // @jdk.internal.vm.annotation.Contended + private static volatile boolean doneQueueing = false; + + private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup(); + private static final VarHandle headHandle; + private static final VarHandle tailHandle; + private static final VarHandle doneHandle; + + static { + try { + headHandle = LOOKUP.findStaticVarHandle(CalculateAverage_asun.class, "head", int.class); + tailHandle = LOOKUP.findStaticVarHandle(CalculateAverage_asun.class, "tail", int.class); + doneHandle = LOOKUP.findStaticVarHandle(CalculateAverage_asun.class, "doneQueueing", boolean.class); + } + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private static final ArrayBlockingQueue workerOutput = new ArrayBlockingQueue<>(Runtime.getRuntime().availableProcessors()); + + private static class Worker implements Runnable { + private long segmentStart; + private long segmentEnd; + + private final MemorySegment ms; + + private Worker(MemorySegment ms) { + this.ms = ms; + } + + @Override + public void run() { var resultMap = new ByteArrayToResultMap(); - long segmentEnd = segment.end(); - try (var fileChannel = (FileChannel) Files.newByteChannel(Path.of(filename), StandardOpenOption.READ)) { - var bb = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segmentEnd - segment.start()); - var ms = MemorySegment.ofBuffer(bb); + var ms = this.ms.asSlice(0); + var msAddr = ms.address(); + var actualLimit = ms.byteSize(); + var buffer = new byte[100 + VECTOR_SIZE]; - // Up to 100 characters for a city name - var buffer = new byte[100 + VECTOR_SIZE]; + while (pollSegment()) { long startLine; - long pos = 0; - long limit = ms.byteSize(); - long vectorLimit = limit - VECTOR_SIZE; + long pos = segmentStart; + long limit = segmentEnd; + long vectorLimit = Math.min(limit, actualLimit - VECTOR_SIZE); + long longLimit = Math.min(limit, actualLimit - 8); // int[] lastHashMult = new int[]{ 7, 31, 63, 15, 255, 127, 3, 511 }; // IntVector lastMul = IntVector.fromArray(INT_SPECIES, lastHashMult, 0); @@ -117,12 +168,15 @@ public class CalculateAverage_asun { int nameLen = (int) (currentPosition - startLine); currentPosition++; - if (currentPosition >= limit - 8) { + if (currentPosition >= longLimit) { break; } - long g = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, currentPosition); - int negative = (g & 0xff) == '-' ? -1 : 1; + long g = UNSAFE.getLong(msAddr + currentPosition); + // long g = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, currentPosition); + boolean minus = (g & 0xff) == '-'; + long minusL = (minus ? 1L : 0L) - 1; + int negative = minus ? -1 : 1; // 00101101 MINUS // 00101110 PERIOD @@ -136,17 +190,13 @@ public class CalculateAverage_asun { int tzc = Long.numberOfTrailingZeros(lf); long bytesToLF = tzc / 8; - int shift = 64 - tzc & 0b111000; + int shift = 72 - tzc & 0b111000; - long reversedDigits = Long.reverseBytes(g) >> shift; - long digitBits = reversedDigits & (0x1010101010101010L >> shift); - long digitsExt = (digitBits >> 1 | digitBits >> 2 | digitBits >> 3 | digitBits >> 4); + long reversedDigits = Long.reverseBytes(g & (0xFFFFFFFFFFFFFF00L | minusL)) >> shift; - long digitsOnly = Long.compress(reversedDigits, digitsExt); - - long temp = (digitsOnly & 0xf) - + 10 * ((digitsOnly >> 4) & 0xf) - + 100 * ((digitsOnly >> 8) & 0xf); + long temp = (reversedDigits & 0xf) + + 10 * ((reversedDigits >> 16) & 0xf) + + 100 * ((reversedDigits >> 24) & 0xf); temp *= negative; @@ -194,13 +244,88 @@ public class CalculateAverage_asun { resultMap.putOrMerge(buffer, 0, offset, temp, hash); pos = currentPosition; } - return resultMap; } - catch (IOException e) { - throw new RuntimeException(e); + + workerOutput.add(resultMap); + } + + private boolean pollSegment() { + int head; + int tail; + + do { + head = (int) headHandle.getAcquire(); + tail = (int) tailHandle.getAcquire(); + + while (head >= tail) { + if ((boolean) doneHandle.getAcquire()) { + return false; + } + + head = (int) headHandle.getAcquire(); + tail = (int) tailHandle.getAcquire(); + } + } while (!headHandle.compareAndSet(head, head + 1)); + + segmentStart = segmentQueue.getPlain(head * 2); + segmentEnd = segmentQueue.getPlain(head * 2 + 1); + + return true; + } + + } + + public static void main(String[] args) throws IOException, ExecutionException, InterruptedException { + // long start = System.currentTimeMillis(); + var filename = args.length == 0 ? FILE : args[0]; + var file = new File(filename); + + @SuppressWarnings("resource") + var fileChannel = (FileChannel) Files.newByteChannel(Path.of(filename), StandardOpenOption.READ); + var ms = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size(), Arena.global()); + + long fileSize = file.length(); + long segmentSize = 10_000_000; + int numberOfSegments = (int) (file.length() / segmentSize + 1) * 2; + segmentQueue = new AtomicLongArray(numberOfSegments); + int tail = 0; + + int processors = Runtime.getRuntime().availableProcessors(); + + Thread.ofPlatform().daemon().start(() -> { + for (int i = 0; i < processors - 1; i++) { + Thread.ofPlatform().daemon().start(new Worker(ms)); } - }).parallel().flatMap(partition -> partition.getAll().stream()) - .collect(Collectors.toMap(e -> new String(e.key()), Entry::value, CalculateAverage_asun::merge, TreeMap::new)); + + new Worker(ms).run(); + }); + + long segStart = 0; + while (segStart < fileSize) { + long segEnd = findSegment(ms, Math.min(segStart + segmentSize, fileSize), fileSize); + segmentQueue.setRelease(tail * 2, segStart); + segmentQueue.setRelease(tail * 2 + 1, segEnd); + tailHandle.setRelease(++tail); + + segStart = segEnd; + } + + doneHandle.setRelease(true); + + // System.out.println(System.currentTimeMillis() - start); + + var resultsMap = new TreeMap(); + for (int i = 0; i < processors; i++) { + var result = workerOutput.take(); + + // System.out.println(i + " " + (System.currentTimeMillis() - start)); + + for (Entry e : result.getAll()) { + resultsMap.merge(new String(e.key()), e.value(), CalculateAverage_asun::merge); + } + + // System.out.println(i + " " + (System.currentTimeMillis() - start)); + } System.out.println(resultsMap); @@ -209,29 +334,6 @@ public class CalculateAverage_asun { Runtime.getRuntime().halt(0); } - private static List getFileSegments(File file) throws IOException { - int numberOfSegments = Runtime.getRuntime().availableProcessors() * 8; - long fileSize = file.length(); - long segmentSize = fileSize / numberOfSegments; - List segments = new ArrayList<>(numberOfSegments); - // Pointless to split small files - if (segmentSize < 1_000_000) { - segments.add(new FileSegment(0, fileSize)); - return segments; - } - try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) { - for (int i = 0; i < numberOfSegments; i++) { - long segStart = i * segmentSize; - long segEnd = (i == numberOfSegments - 1) ? fileSize : segStart + segmentSize; - segStart = findSegment(i, 0, randomAccessFile, segStart, segEnd); - segEnd = findSegment(i, numberOfSegments - 1, randomAccessFile, segEnd, fileSize); - - segments.add(new FileSegment(segStart, segEnd)); - } - } - return segments; - } - private static Result merge(Result v, Result value) { return merge(v, value.min, value.max, value.sum, value.count); } @@ -244,14 +346,14 @@ public class CalculateAverage_asun { return v; } - private static long findSegment(int i, int skipSegment, RandomAccessFile raf, long location, long fileSize) throws IOException { - if (i != skipSegment) { - raf.seek(location); - while (location < fileSize) { + private static long findSegment(MemorySegment ms, long location, long fileSize) { + while (location < fileSize) { + if (ms.get(ValueLayout.JAVA_BYTE, location) == '\n') { location++; - if (raf.read() == '\n') - break; + break; } + + location++; } return location; } @@ -283,9 +385,6 @@ public class CalculateAverage_asun { record Entry(byte[] key, Result value) { } - record FileSegment(long start, long end) { - } - static class ByteArrayToResultMap { public static final int MAPSIZE = 1024 * 128; Result[] slots = new Result[MAPSIZE];