diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java b/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java index 0764b65..a7baf9b 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java @@ -16,19 +16,14 @@ package dev.morling.onebrc; import java.io.IOException; -import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousFileChannel; import java.nio.channels.CompletionHandler; import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeMap; @@ -45,91 +40,101 @@ import java.util.concurrent.LinkedBlockingQueue; * @author Xylitol */ public class CalculateAverage_C5H12O5 { - private static final int BUFFER_CAPACITY = 1024 * 1024; + private static final int BUFFER_CAPACITY = 1024 * 1024 * 10; private static final int MAP_CAPACITY = 10000; - private static final int QUEUE_CAPACITY = 2; + private static final int PROCESSORS = Runtime.getRuntime().availableProcessors(); + private static final BlockingQueue BYTES_QUEUE = new LinkedBlockingQueue<>(PROCESSORS); + private static long readPosition; public static void main(String[] args) throws Exception { - // Files.list(Paths.get("./src/test/resources/samples")) - // .filter(file -> file.toString().endsWith(".txt")) - // .forEach(file -> { - // try { - // String actual = calc(file); - // String expected = Files.readAllLines(Paths.get(file.toString().replace(".txt", ".out"))).get(0); - // System.out.println(file.getFileName() + ": " + expected.equals(actual)); - // } catch (Exception e) { - // System.out.println(file.getFileName() + ": " + false); - // e.printStackTrace(); - // } - // }); - // long start = System.currentTimeMillis(); - System.out.println(calc(Paths.get("./measurements.txt"))); - // System.out.println("Time: " + (System.currentTimeMillis() - start) + "ms"); + System.out.println(calc("./measurements.txt")); } /** * Calculate the average. */ - public static String calc(Path file) throws IOException, ExecutionException, InterruptedException { - long[] positions = fragment(file, Runtime.getRuntime().availableProcessors()); - FutureTask>[] tasks = new FutureTask[positions.length]; - for (int i = 0; i < positions.length; i++) { - tasks[i] = new FutureTask<>(new Task(file, (i == 0 ? 0 : positions[i - 1] + 1), positions[i])); - new Thread(tasks[i]).start(); - } + public static String calc(String path) throws IOException, ExecutionException, InterruptedException { + readPosition = 0; Map result = HashMap.newHashMap(MAP_CAPACITY); - for (FutureTask> task : tasks) { - task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge)); + // read and offer to queue + try (AsynchronousFileChannel channel = AsynchronousFileChannel.open( + Paths.get(path), Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor())) { + ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY); + channel.read(buffer, readPosition, buffer, new CompletionHandler<>() { + @Override + public void completed(Integer bytesRead, ByteBuffer buffer) { + try { + if (bytesRead > 0) { + for (int i = buffer.position() - 1; i >= 0; i--) { + if (buffer.get(i) == '\n') { + buffer.limit(i + 1); + break; + } + } + buffer.flip(); + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + readPosition += buffer.limit(); + BYTES_QUEUE.put(bytes); + buffer.clear(); + channel.read(buffer, readPosition, buffer, this); + } + else { + for (int i = 0; i < PROCESSORS; i++) { + BYTES_QUEUE.put(new byte[0]); + } + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + @Override + public void failed(Throwable exc, ByteBuffer buffer) { + // ignore + } + }); + + @SuppressWarnings("unchecked") + FutureTask>[] tasks = new FutureTask[PROCESSORS]; + for (int i = 0; i < PROCESSORS; i++) { + tasks[i] = new FutureTask<>(new Task()); + new Thread(tasks[i]).start(); + } + for (FutureTask> task : tasks) { + task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge)); + } } return new TreeMap<>(result).toString(); } - /** - * Fragment the file into chunks. - */ - private static long[] fragment(Path filePath, int chunkNum) throws IOException { - long fileSize = Files.size(filePath); - long chunkSize = fileSize / chunkNum; - long[] positions = new long[chunkNum]; - try (RandomAccessFile file = new RandomAccessFile(filePath.toFile(), "r")) { - long position = chunkSize; - for (int i = 0; i < chunkNum - 1; i++) { - if (position >= fileSize) { - break; - } - file.seek(position); - while (file.read() != '\n') { - position++; - } - positions[i] = position; - position += chunkSize; - } - } - positions[chunkNum - 1] = fileSize; - return Arrays.stream(positions).filter(value -> value != 0).toArray(); - } - /** * The measurement name. */ - private record MeasurementName(byte[] bytes) { + private record MeasurementName(byte[] bytes, int length) { @Override - public boolean equals(Object other) { - if (!(other instanceof MeasurementName)) { + public boolean equals(Object name) { + MeasurementName other = (MeasurementName) name; + if (other.length != length) { return false; } - return Arrays.equals(bytes, ((MeasurementName) other).bytes); + return Arrays.compare(bytes, 0, length, other.bytes, 0, length) == 0; } @Override public int hashCode() { - return Arrays.hashCode(bytes); + int result = 1; + for (int i = 0; i < length; i++) { + result = 31 * result + bytes[i]; + } + return result; } @Override public String toString() { - return new String(bytes, StandardCharsets.UTF_8); + return new String(bytes, 0, length, StandardCharsets.UTF_8); } } @@ -168,116 +173,53 @@ public class CalculateAverage_C5H12O5 { } /** - * The task to read and calculate. + * The task to calculate. */ private static class Task implements Callable> { - private final Path file; - private long readPosition; - private long calcPosition; - private final long limitSize; - private final BlockingQueue bytesQueue = new LinkedBlockingQueue<>(QUEUE_CAPACITY); - - public Task(Path file, long position, long limitSize) { - this.file = file; - this.readPosition = position; - this.calcPosition = position; - this.limitSize = limitSize; - } @Override - public Map call() throws IOException { - // read and offer to queue - AsynchronousFileChannel channel = AsynchronousFileChannel.open( - file, Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor()); - ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY); - channel.read(buffer, readPosition, buffer, new CompletionHandler<>() { - @Override - public void completed(Integer bytesRead, ByteBuffer buffer) { - if (bytesRead > 0 && readPosition < limitSize) { - try { - buffer.flip(); - byte[] bytes = new byte[buffer.remaining()]; - buffer.get(bytes); - readPosition += bytesRead; - if (readPosition > limitSize) { - int diff = (int) (readPosition - limitSize); - byte[] newBytes = new byte[bytes.length - diff]; - System.arraycopy(bytes, 0, newBytes, 0, newBytes.length); - bytesQueue.put(newBytes); - } - else { - bytesQueue.put(bytes); - buffer.clear(); - channel.read(buffer, readPosition, buffer, this); - } - } - catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - } - } - - @Override - public void failed(Throwable exc, ByteBuffer buffer) { - // ignore - } - }); - + public Map call() throws InterruptedException { // poll from queue and calculate Map result = HashMap.newHashMap(MAP_CAPACITY); - byte[] readBytes = null; - byte[] remaining = null; - while (calcPosition < limitSize) { - readBytes = bytesQueue.poll(); - if (readBytes != null) { - List lines = split(readBytes, (byte) '\n'); - for (int i = 0; i < lines.size(); i++) { - byte[] lineBytes = lines.get(i); - if (i == 0 && remaining != null) { - byte[] newBytes = new byte[remaining.length + lineBytes.length]; - System.arraycopy(remaining, 0, newBytes, 0, remaining.length); - System.arraycopy(lineBytes, 0, newBytes, remaining.length, lineBytes.length); - lineBytes = newBytes; + for (byte[] bytes = BYTES_QUEUE.take(); true; bytes = BYTES_QUEUE.take()) { + if (bytes.length == 0) { + break; + } + int start = 0; + for (int end = 0; end < bytes.length; end++) { + if (bytes[end] == '\n') { + byte[] newBytes = new byte[end - start]; + System.arraycopy(bytes, start, newBytes, 0, newBytes.length); + int semicolon = newBytes.length - 4; + for (; semicolon >= 0; semicolon--) { + if (newBytes[semicolon] == ';') { + break; + } } - if (i == lines.size() - 1) { - remaining = lineBytes; - break; + MeasurementName station = new MeasurementName(newBytes, semicolon); + int value = toInt(newBytes, semicolon + 1); + MeasurementData data = result.get(station); + if (data != null) { + data.merge(value, value, value, 1); } - agg(result, lineBytes); + else { + result.put(station, new MeasurementData(value)); + } + start = end + 1; } - calcPosition += readBytes.length; } } - if (remaining != null && remaining.length > 0) { - agg(result, remaining); - } - channel.close(); return result; } - /** - * Aggregate the measurement data. - */ - private static void agg(Map result, byte[] bytes) { - List parts = split(bytes, (byte) ';'); - MeasurementName station = new MeasurementName(parts.getFirst()); - int value = toInt(parts.getLast()); - MeasurementData data = result.get(station); - if (data != null) { - data.merge(value, value, value, 1); - } - else { - result.put(station, new MeasurementData(value)); - } - } - /** * Convert the byte array to int. */ - private static int toInt(byte[] bytes) { + private static int toInt(byte[] bytes, int start) { boolean negative = false; int result = 0; - for (byte b : bytes) { + for (int i = start; i < bytes.length; i++) { + byte b = bytes[i]; if (b == '-') { negative = true; continue; @@ -288,27 +230,5 @@ public class CalculateAverage_C5H12O5 { } return negative ? -result : result; } - - /** - * Split the byte array by given byte. - */ - private static List split(byte[] bytes, byte separator) { - List result = new ArrayList<>(); - int start = 0; - for (int end = 0; end < bytes.length; end++) { - if (bytes[end] == separator) { - byte[] newBytes = new byte[end - start]; - System.arraycopy(bytes, start, newBytes, 0, newBytes.length); - result.add(newBytes); - start = end + 1; - } - } - if (start <= bytes.length) { - byte[] newBytes = new byte[bytes.length - start]; - System.arraycopy(bytes, start, newBytes, 0, newBytes.length); - result.add(newBytes); - } - return result; - } } }