diff --git a/calculate_average_C5H12O5.sh b/calculate_average_C5H12O5.sh new file mode 100755 index 0000000..64c6c76 --- /dev/null +++ b/calculate_average_C5H12O5.sh @@ -0,0 +1,19 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +JAVA_OPTS="--enable-preview" +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_C5H12O5 diff --git a/prepare_C5H12O5.sh b/prepare_C5H12O5.sh new file mode 100755 index 0000000..06b81c4 --- /dev/null +++ b/prepare_C5H12O5.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Uncomment below to use sdk +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.1-graal 1>&2 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java b/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java new file mode 100644 index 0000000..0764b65 --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java @@ -0,0 +1,314 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousFileChannel; +import java.nio.channels.CompletionHandler; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.FutureTask; +import java.util.concurrent.LinkedBlockingQueue; + +/** + * Calculates the average using AIO and multiple threads. 
/**
 * Computes the minimum, mean and maximum measurement per station from a text file whose
 * lines have the form {@code <station>;<value>}, where the value has one fractional digit.
 *
 * <p>The file is cut into newline-aligned fragments, one per available processor. Each
 * fragment is read through an {@link AsynchronousFileChannel} and parsed on its own thread;
 * the per-fragment results are merged and printed in station-name order.
 *
 * @author Xylitol
 */
public class CalculateAverage_C5H12O5 {
    private static final int BUFFER_CAPACITY = 1024 * 1024;
    private static final int MAP_CAPACITY = 10000;
    private static final int QUEUE_CAPACITY = 2;

    public static void main(String[] args) throws Exception {
        System.out.println(calc(Paths.get("./measurements.txt")));
    }

    /**
     * Calculates the aggregated measurements of the given file.
     *
     * @param file the measurements file
     * @return the stations sorted by name, rendered as {@code {name=min/mean/max, ...}}
     * @throws IOException if the file cannot be inspected while fragmenting it
     * @throws ExecutionException if reading or parsing one of the fragments fails
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public static String calc(Path file) throws IOException, ExecutionException, InterruptedException {
        long[] positions = fragment(file, Runtime.getRuntime().availableProcessors());
        List<FutureTask<Map<MeasurementName, MeasurementData>>> tasks = new ArrayList<>(positions.length);
        long start = 0;
        for (long end : positions) {
            FutureTask<Map<MeasurementName, MeasurementData>> task = new FutureTask<>(new Task(file, start, end));
            tasks.add(task);
            new Thread(task).start();
            // positions hold the offset of a '\n'; the next fragment begins right after it
            start = end + 1;
        }
        Map<String, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
        for (FutureTask<Map<MeasurementName, MeasurementData>> task : tasks) {
            task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge));
        }
        return new TreeMap<>(result).toString();
    }

    /**
     * Fragments the file into at most {@code chunkNum} newline-aligned pieces.
     *
     * @param filePath the file to fragment
     * @param chunkNum the desired number of fragments; values below 1 are treated as 1
     * @return the exclusive end offset of every fragment in ascending order; every entry but
     *         the last is the offset of a {@code '\n'}, the last one is the file size. The
     *         array is empty for an empty file.
     * @throws IOException if the file cannot be read
     */
    private static long[] fragment(Path filePath, int chunkNum) throws IOException {
        long fileSize = Files.size(filePath);
        if (fileSize == 0) {
            return new long[0];
        }
        int maxFragments = Math.max(1, chunkNum);
        // at least one byte per step, so boundaries always advance on tiny files
        long chunkSize = Math.max(1, fileSize / maxFragments);
        long[] positions = new long[maxFragments];
        int used = 0;
        try (RandomAccessFile file = new RandomAccessFile(filePath.toFile(), "r")) {
            long position = chunkSize;
            while (used < maxFragments - 1 && position < fileSize) {
                file.seek(position);
                int b = file.read();
                while (b != -1 && b != '\n') {
                    position++;
                    b = file.read();
                }
                if (b == -1) {
                    // no newline before the end of the file: the last fragment takes the rest
                    break;
                }
                positions[used++] = position;
                position += chunkSize;
            }
        }
        positions[used++] = fileSize;
        return Arrays.copyOf(positions, used);
    }

    /**
     * The measurement name: the raw bytes of a station name, compared by content so that it
     * can serve as a hash map key.
     */
    private record MeasurementName(byte[] bytes) {

        @Override
        public boolean equals(Object other) {
            return other instanceof MeasurementName that && Arrays.equals(bytes, that.bytes);
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(bytes);
        }

        @Override
        public String toString() {
            return new String(bytes, StandardCharsets.UTF_8);
        }
    }

    /**
     * The measurement data: running minimum, maximum, sum and count of one station, with
     * values held as integers in tenths.
     */
    private static class MeasurementData {
        private int min;
        private int max;
        // long: the sum of many int-sized values does not fit into an int
        private long sum;
        private int count;

        public MeasurementData(int value) {
            this.min = value;
            this.max = value;
            this.sum = value;
            this.count = 1;
        }

        /**
         * Folds another aggregate into this one and returns this instance.
         */
        public MeasurementData merge(MeasurementData data) {
            return merge(data.min, data.max, data.sum, data.count);
        }

        /**
         * Folds the given partial aggregate into this one and returns this instance.
         */
        public MeasurementData merge(int min, int max, long sum, int count) {
            this.min = Math.min(this.min, min);
            this.max = Math.max(this.max, max);
            this.sum += sum;
            this.count += count;
            return this;
        }

        @Override
        public String toString() {
            return (min / 10.0) + "/" + (Math.round((double) sum / count) / 10.0) + "/" + (max / 10.0);
        }
    }

    /**
     * The task to read and calculate one fragment of the file: completion handlers read the
     * byte range {@code [position, limitSize)} and hand the bytes over through a bounded
     * queue, while the thread running {@link #call()} parses them.
     */
    private static class Task implements Callable<Map<MeasurementName, MeasurementData>> {
        /** Queue marker, compared by identity, telling the consumer that no more bytes follow. */
        private static final byte[] END_OF_RANGE = new byte[0];

        private final Path file;
        private final long limitSize;
        private final BlockingQueue<byte[]> bytesQueue = new LinkedBlockingQueue<>(QUEUE_CAPACITY);
        // written by the completion handlers only, after the first read has been requested
        private long readPosition;
        // set by a completion handler before it enqueues END_OF_RANGE
        private volatile Throwable failure;

        public Task(Path file, long position, long limitSize) {
            this.file = file;
            this.readPosition = position;
            this.limitSize = limitSize;
        }

        /**
         * Reads and aggregates the fragment.
         *
         * @return the aggregated data per station of this fragment
         * @throws IOException if the file cannot be opened or a read fails
         * @throws InterruptedException if the thread is interrupted while waiting for bytes
         */
        @Override
        public Map<MeasurementName, MeasurementData> call() throws IOException, InterruptedException {
            if (readPosition >= limitSize) {
                // empty fragment, e.g. the file ends right after the previous boundary
                return new HashMap<>();
            }
            Map<MeasurementName, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
            var ioExecutor = Executors.newVirtualThreadPerTaskExecutor();
            try (AsynchronousFileChannel channel = AsynchronousFileChannel.open(
                    file, Set.of(StandardOpenOption.READ), ioExecutor)) {
                startReading(channel);
                consume(result);
            }
            finally {
                // interrupts a handler that may still be blocked on the queue after a failure
                ioExecutor.shutdownNow();
            }
            if (failure != null) {
                throw new IOException("Failed to read " + file, failure);
            }
            return result;
        }

        /**
         * Requests the first read; every completed read enqueues its bytes and then either
         * requests the next read or enqueues {@link #END_OF_RANGE}.
         */
        private void startReading(AsynchronousFileChannel channel) {
            ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY);
            requestNext(channel, buffer, new CompletionHandler<>() {
                @Override
                public void completed(Integer bytesRead, ByteBuffer buf) {
                    try {
                        if (bytesRead <= 0) {
                            failure = new java.io.EOFException(
                                    "Unexpected end of file at offset " + readPosition + " in " + file);
                            enqueue(END_OF_RANGE);
                            return;
                        }
                        buf.flip();
                        byte[] bytes = new byte[buf.remaining()];
                        buf.get(bytes);
                        readPosition += bytes.length;
                        if (!enqueue(bytes)) {
                            return;
                        }
                        if (readPosition < limitSize) {
                            requestNext(channel, buf, this);
                        }
                        else {
                            enqueue(END_OF_RANGE);
                        }
                    }
                    catch (RuntimeException e) {
                        // never leave the consumer waiting for bytes that will not arrive
                        failure = e;
                        enqueue(END_OF_RANGE);
                    }
                }

                @Override
                public void failed(Throwable exc, ByteBuffer buf) {
                    failure = exc;
                    enqueue(END_OF_RANGE);
                }
            });
        }

        /**
         * Requests the next read at {@code readPosition}, capped so that it never goes beyond
         * {@code limitSize}.
         */
        private void requestNext(AsynchronousFileChannel channel, ByteBuffer buffer,
                                 CompletionHandler<Integer, ByteBuffer> handler) {
            buffer.clear();
            buffer.limit((int) Math.min(buffer.capacity(), limitSize - readPosition));
            channel.read(buffer, readPosition, buffer, handler);
        }

        /**
         * Puts the bytes into the queue, waiting for space if necessary.
         *
         * @return {@code false} if the thread was interrupted before the bytes were queued
         */
        private boolean enqueue(byte[] bytes) {
            try {
                bytesQueue.put(bytes);
                return true;
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
        }

        /**
         * Takes byte arrays from the queue until {@link #END_OF_RANGE} arrives and aggregates
         * every line, carrying a line that is cut by an array boundary over to the next array.
         */
        private void consume(Map<MeasurementName, MeasurementData> result) throws InterruptedException {
            byte[] pending = null;
            for (byte[] chunk = bytesQueue.take(); chunk != END_OF_RANGE; chunk = bytesQueue.take()) {
                int lineStart = 0;
                for (int i = 0; i < chunk.length; i++) {
                    if (chunk[i] != '\n') {
                        continue;
                    }
                    if (pending != null) {
                        byte[] line = join(pending, chunk, lineStart, i);
                        agg(result, line, 0, line.length);
                        pending = null;
                    }
                    else {
                        agg(result, chunk, lineStart, i);
                    }
                    lineStart = i + 1;
                }
                if (lineStart < chunk.length) {
                    pending = join(pending, chunk, lineStart, chunk.length);
                }
            }
            // the fragment's last line has no trailing '\n' inside the fragment
            if (pending != null) {
                agg(result, pending, 0, pending.length);
            }
        }

        /**
         * Concatenates {@code head} (may be {@code null}) and {@code tail[from, to)}.
         */
        private static byte[] join(byte[] head, byte[] tail, int from, int to) {
            int headLength = head == null ? 0 : head.length;
            byte[] joined = new byte[headLength + (to - from)];
            if (head != null) {
                System.arraycopy(head, 0, joined, 0, headLength);
            }
            System.arraycopy(tail, from, joined, headLength, to - from);
            return joined;
        }

        /**
         * Aggregate the measurement data of the line {@code bytes[from, to)}. Blank lines are
         * skipped; the value is everything after the last {@code ';'}.
         *
         * @throws IllegalArgumentException if a non-blank line contains no {@code ';'}
         */
        private static void agg(Map<MeasurementName, MeasurementData> result, byte[] bytes, int from, int to) {
            if (from == to) {
                return;
            }
            // the value is short, so the separator is found fastest from the end
            int separator = to - 1;
            while (separator >= from && bytes[separator] != ';') {
                separator--;
            }
            if (separator < from) {
                throw new IllegalArgumentException("Malformed measurement line: "
                        + new String(bytes, from, to - from, StandardCharsets.UTF_8));
            }
            MeasurementName station = new MeasurementName(Arrays.copyOfRange(bytes, from, separator));
            int value = toInt(bytes, separator + 1, to);
            MeasurementData data = result.get(station);
            if (data != null) {
                data.merge(value, value, value, 1);
            }
            else {
                result.put(station, new MeasurementData(value));
            }
        }

        /**
         * Convert {@code bytes[from, to)} to int by dropping the decimal point, so that
         * {@code "-12.3"} becomes {@code -123}.
         */
        private static int toInt(byte[] bytes, int from, int to) {
            boolean negative = false;
            int result = 0;
            for (int i = from; i < to; i++) {
                byte b = bytes[i];
                if (b == '-') {
                    negative = true;
                }
                else if (b != '.') {
                    result = result * 10 + (b - '0');
                }
            }
            return negative ? -result : result;
        }
    }
}