diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java b/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java index a7baf9b..4c0351a 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java @@ -15,136 +15,386 @@ */ package dev.morling.onebrc; +import sun.misc.Unsafe; + import java.io.IOException; +import java.io.RandomAccessFile; +import java.lang.reflect.Field; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.channels.AsynchronousFileChannel; import java.nio.channels.CompletionHandler; import java.nio.charset.StandardCharsets; -import java.nio.file.Paths; +import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.StandardOpenOption; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeMap; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.FutureTask; -import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.LinkedTransferQueue; +import java.util.concurrent.TransferQueue; /** - * Calculates the average using AIO and multiple threads. + * Results on Mac mini (Apple M2 with 8-core CPU / 8GB unified memory): + *
+ *   using AIO and multiple threads:
+ *     120.15s user 4.33s system 710% cpu 17.522 total
+ *
+ *   reduce the number of memory copies:
+ *      45.87s user 2.82s system 530% cpu  9.185 total
+ *
+ *   processing the byte array backwards and using bitwise operations to find a specific byte (inspired by thomaswue):
+ *      25.38s user 3.44s system 342% cpu  8.406 total
+ * 
* * @author Xylitol */ +@SuppressWarnings("unchecked") public class CalculateAverage_C5H12O5 { - private static final int BUFFER_CAPACITY = 1024 * 1024 * 10; - private static final int MAP_CAPACITY = 10000; - private static final int PROCESSORS = Runtime.getRuntime().availableProcessors(); - private static final BlockingQueue BYTES_QUEUE = new LinkedBlockingQueue<>(PROCESSORS); - private static long readPosition; + private static final int AVAILABLE_PROCESSOR_NUM = Runtime.getRuntime().availableProcessors(); + private static final int TRANSFER_QUEUE_CAPACITY = 1024 / 16 / AVAILABLE_PROCESSOR_NUM; // 1GB memory max + private static final int BYTE_BUFFER_CAPACITY = 1024 * 1024 * 16; // 16MB one time + private static final int EXPECTED_MAPPINGS_NUM = 10000; + + /** + * Fragment the file into chunks. + */ + private static long[] fragment(Path path) throws IOException { + long size = Files.size(path); + long chunk = size / AVAILABLE_PROCESSOR_NUM; + List positions = new ArrayList<>(); + try (RandomAccessFile file = new RandomAccessFile(path.toFile(), "r")) { + long position = chunk; + for (int i = 0; i < AVAILABLE_PROCESSOR_NUM - 1; i++) { + if (position >= size) { + break; + } + file.seek(position); + // move the position to the next newline byte + while (file.read() != '\n') { + position++; + } + positions.add(++position); + position += chunk; + } + } + if (positions.isEmpty() || positions.getLast() < size) { + positions.add(size); + } + return positions.stream().mapToLong(Long::longValue).toArray(); + } public static void main(String[] args) throws Exception { - System.out.println(calc("./measurements.txt")); + // fragment the input file + Path path = Path.of("./measurements.txt"); + long[] positions = fragment(path); + + // start the calculation tasks + FutureTask>[] tasks = new FutureTask[positions.length]; + for (int i = 0; i < positions.length; i++) { + tasks[i] = new FutureTask<>(new Calculator(path, (i == 0 ? 
0 : positions[i - 1]), positions[i])); + new Thread(tasks[i]).start(); + } + + // wait for the results + Map result = HashMap.newHashMap(EXPECTED_MAPPINGS_NUM); + for (FutureTask> task : tasks) { + task.get().forEach((k, v) -> result.merge(k, v, MeasurementData::merge)); + } + + // sort and print the results + TreeMap sorted = new TreeMap<>(); + for (Map.Entry entry : result.entrySet()) { + sorted.put(new String(entry.getKey().bytes, StandardCharsets.UTF_8), entry.getValue()); + } + System.out.println(sorted); } /** - * Calculate the average. + * The calculation task. */ - public static String calc(String path) throws IOException, ExecutionException, InterruptedException { - readPosition = 0; - Map result = HashMap.newHashMap(MAP_CAPACITY); - // read and offer to queue - try (AsynchronousFileChannel channel = AsynchronousFileChannel.open( - Paths.get(path), Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor())) { - ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY); - channel.read(buffer, readPosition, buffer, new CompletionHandler<>() { + private static class Calculator implements Callable> { + private final TransferQueue transfer = new LinkedTransferQueue<>(); + private final AsynchronousFileChannel asyncChannel; + private final long limit; + private long position; + + public Calculator(Path file, long position, long limit) throws IOException { + ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor(); + this.asyncChannel = AsynchronousFileChannel.open(file, Set.of(StandardOpenOption.READ), executor); + this.position = position; + this.limit = limit; + } + + @Override + public Map call() throws InterruptedException { + ByteBuffer buffer = ByteBuffer.allocateDirect(BYTE_BUFFER_CAPACITY); + asyncChannel.read(buffer, position, buffer, new CompletionHandler<>() { @Override - public void completed(Integer bytesRead, ByteBuffer buffer) { - try { - if (bytesRead > 0) { - for (int i = buffer.position() - 1; i >= 0; i--) 
{ - if (buffer.get(i) == '\n') { - buffer.limit(i + 1); - break; - } - } - buffer.flip(); - byte[] bytes = new byte[buffer.remaining()]; - buffer.get(bytes); - readPosition += buffer.limit(); - BYTES_QUEUE.put(bytes); - buffer.clear(); - channel.read(buffer, readPosition, buffer, this); - } - else { - for (int i = 0; i < PROCESSORS; i++) { - BYTES_QUEUE.put(new byte[0]); + public void completed(Integer readSize, ByteBuffer buffer) { + if (position + readSize >= limit) { + buffer.limit(readSize - (int) (position + readSize - limit)); + } + else { + for (int i = buffer.position() - 1; i >= 0; i--) { + if (buffer.get(i) == '\n') { + // truncate the buffer to the last newline byte + buffer.limit(i + 1); + break; } } } - catch (InterruptedException e) { - Thread.currentThread().interrupt(); + buffer.flip(); + byte[] bytes = new byte[buffer.limit() + 1]; + // add a newline byte at the beginning + bytes[0] = '\n'; + buffer.get(bytes, 1, buffer.limit()); + transfer(bytes); + if ((position += buffer.limit()) < limit) { + buffer.clear(); + asyncChannel.read(buffer, position, buffer, this); + } + else { + // stop signal + transfer(new byte[0]); } } @Override public void failed(Throwable exc, ByteBuffer buffer) { - // ignore + transfer(new byte[0]); } }); + return process(); + } - @SuppressWarnings("unchecked") - FutureTask>[] tasks = new FutureTask[PROCESSORS]; - for (int i = 0; i < PROCESSORS; i++) { - tasks[i] = new FutureTask<>(new Task()); - new Thread(tasks[i]).start(); + /** + * Transfer or put the bytes to the queue. 
+ */ + private void transfer(byte[] bytes) { + try { + if (transfer.size() >= TRANSFER_QUEUE_CAPACITY) { + transfer.transfer(bytes); + } + else { + transfer.put(bytes); + } } - for (FutureTask> task : tasks) { - task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge)); + catch (InterruptedException e) { + throw new RuntimeException(e); } } - return new TreeMap<>(result).toString(); + + /** + * Take and process the bytes from the queue. + */ + private Map process() throws InterruptedException { + Map result = HashMap.newHashMap(EXPECTED_MAPPINGS_NUM); + for (byte[] bytes = transfer.take(); bytes.length > 0; bytes = transfer.take()) { + Station station = new Station(bytes); + // read the bytes backwards + for (int position = bytes.length - 2; position >= 1; position--) { + + // calculate the temperature value + int temperature = bytes[position] - '0' + (bytes[position -= 2] - '0') * 10; + byte unknownByte = bytes[--position]; + int semicolon = switch (unknownByte) { + case ';' -> position; + case '-' -> { + temperature = -temperature; + yield --position; + } + default -> { + temperature += (unknownByte - '0') * 100; + if (bytes[--position] == '-') { + temperature = -temperature; + --position; + } + yield position; + } + }; + + // calculate the station name hash + int hash = 1; + while (true) { + long temp = LineFinder.previousLong(bytes, position); + int distance = LineFinder.NATIVE.fromRight(temp); + if (distance == 0) { + // current byte is '\n' + break; + } + position -= distance; + if (distance == 8) { + // can't find '\n' in previous 8 bytes + hash = 31 * hash + (int) (temp ^ (temp >>> 32)); + continue; + } + // clear the redundant bytes + temp = LineFinder.NATIVE.clearLeft(temp, distance); + hash = 31 * hash + (int) (temp ^ (temp >>> 32)); + } + + // merge data to the result map + MeasurementData data = result.get(station.slice(hash, position + 1, semicolon)); + if (data == null) { + result.put(station.copy(), new 
MeasurementData(temperature)); + } else { + data.merge(temperature); + } + } + } + return result; + } } /** - * The measurement name. + * To find the nearest newline byte position in a long. */ - private record MeasurementName(byte[] bytes, int length) { + private interface LineFinder { + // choose the implementation according to the native byte order + LineFinder NATIVE = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? LELineFinder.INST : BELineFinder.INST; + + Unsafe UNSAFE = initUnsafe(); + int BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + int LONG_BYTES = Long.SIZE / Byte.SIZE; + + static Unsafe initUnsafe() { + try { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + return (Unsafe) theUnsafe.get(Unsafe.class); + } + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + static long previousLong(byte[] bytes, long offset) { + return UNSAFE.getLong(bytes, BYTE_ARRAY_BASE_OFFSET + offset + 1 - LONG_BYTES); + } + + /** + * Mark the highest bit of newline byte (0x0A) to 1. + */ + static long markHighestBit(long longBytes) { + long temp = longBytes ^ 0x0A0A0A0A0A0A0A0AL; + return (temp - 0x0101010101010101L) & ~temp & 0x8080808080808080L; + } + + /** + * Find the nearest newline byte position from right to left. + */ + int fromRight(long longBytes); + + /** + * Clear the left bytes out of the range. 
+ */ + long clearLeft(long longBytes, int keepNum); + + enum LELineFinder implements LineFinder { + INST; + + private static final long[] MASKS = new long[8]; + + static { + for (int i = 1; i <= 7; i++) { + MASKS[i] = 0xFFFFFFFFFFFFFFFFL << ((8 - i) << 3); + } + } + + @Override + public int fromRight(long longBytes) { + return Long.numberOfLeadingZeros(markHighestBit(longBytes)) >>> 3; + } + + @Override + public long clearLeft(long longBytes, int keepNum) { + return longBytes & MASKS[keepNum]; + } + } + + enum BELineFinder implements LineFinder { + INST; + + private static final long[] MASKS = new long[8]; + + static { + for (int i = 1; i <= 7; i++) { + MASKS[i] = 0xFFFFFFFFFFFFFFFFL >>> ((8 - i) << 3); + } + } + + @Override + public int fromRight(long longBytes) { + return Long.numberOfTrailingZeros(markHighestBit(longBytes)) >>> 3; + } + + @Override + public long clearLeft(long longBytes, int keepNum) { + return longBytes & MASKS[keepNum]; + } + } + } + + /** + * The station name wrapper ( bytes[from, to) ). 
+ */ + private static class Station { + private final byte[] bytes; + private int from; + private int to; + private int hash; + + public Station(byte[] bytes) { + this(bytes, 0, 0, 0); + } + + public Station(byte[] bytes, int hash, int from, int to) { + this.bytes = bytes; + this.slice(hash, from, to); + } + + public Station slice(int hash, int from, int to) { + this.hash = hash; + this.from = from; + this.to = to; + return this; + } + + public Station copy() { + int length = to - from; + byte[] newBytes = new byte[length]; + System.arraycopy(bytes, from, newBytes, 0, length); + return new Station(newBytes, hash, 0, length); + } @Override - public boolean equals(Object name) { - MeasurementName other = (MeasurementName) name; - if (other.length != length) { - return false; - } - return Arrays.compare(bytes, 0, length, other.bytes, 0, length) == 0; + public boolean equals(Object station) { + Station other = (Station) station; + return Arrays.equals(bytes, from, to, other.bytes, other.from, other.to); } @Override public int hashCode() { - int result = 1; - for (int i = 0; i < length; i++) { - result = 31 * result + bytes[i]; - } - return result; - } - - @Override - public String toString() { - return new String(bytes, 0, length, StandardCharsets.UTF_8); + return hash; } } /** - * The measurement data. + * The measurement data wrapper ( temperature * 10 ). 
*/ private static class MeasurementData { private int min; private int max; - private int sum; + private long sum; private int count; public MeasurementData(int value) { @@ -154,11 +404,15 @@ public class CalculateAverage_C5H12O5 { this.count = 1; } - public MeasurementData merge(MeasurementData data) { - return merge(data.min, data.max, data.sum, data.count); + public MeasurementData merge(int value) { + return merge(value, value, value, 1); } - public MeasurementData merge(int min, int max, int sum, int count) { + public MeasurementData merge(MeasurementData other) { + return merge(other.min, other.max, other.sum, other.count); + } + + public MeasurementData merge(int min, int max, long sum, int count) { this.min = Math.min(this.min, min); this.max = Math.max(this.max, max); this.sum += sum; @@ -168,67 +422,7 @@ public class CalculateAverage_C5H12O5 { @Override public String toString() { - return (min / 10.0) + "/" + (Math.round((double) sum / count) / 10.0) + "/" + (max / 10.0); - } - } - - /** - * The task to calculate. 
- */ - private static class Task implements Callable> { - - @Override - public Map call() throws InterruptedException { - // poll from queue and calculate - Map result = HashMap.newHashMap(MAP_CAPACITY); - for (byte[] bytes = BYTES_QUEUE.take(); true; bytes = BYTES_QUEUE.take()) { - if (bytes.length == 0) { - break; - } - int start = 0; - for (int end = 0; end < bytes.length; end++) { - if (bytes[end] == '\n') { - byte[] newBytes = new byte[end - start]; - System.arraycopy(bytes, start, newBytes, 0, newBytes.length); - int semicolon = newBytes.length - 4; - for (; semicolon >= 0; semicolon--) { - if (newBytes[semicolon] == ';') { - break; - } - } - MeasurementName station = new MeasurementName(newBytes, semicolon); - int value = toInt(newBytes, semicolon + 1); - MeasurementData data = result.get(station); - if (data != null) { - data.merge(value, value, value, 1); - } - else { - result.put(station, new MeasurementData(value)); - } - start = end + 1; - } - } - } - return result; - } - - /** - * Convert the byte array to int. - */ - private static int toInt(byte[] bytes, int start) { - boolean negative = false; - int result = 0; - for (int i = start; i < bytes.length; i++) { - byte b = bytes[i]; - if (b == '-') { - negative = true; - continue; - } - if (b != '.') { - result = result * 10 + (b - '0'); - } - } - return negative ? -result : result; + return STR."\{min / 10.0}/\{Math.round((double) sum / count) / 10.0}/\{max / 10.0}"; } } }