Improved my implementation

This commit is contained in:
Xylitol 2024-01-12 22:18:15 +08:00 committed by Gunnar Morling
parent 09e0311e09
commit 9e5ec51315

View File

@ -16,19 +16,14 @@
package dev.morling.onebrc; package dev.morling.onebrc;
import java.io.IOException; import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousFileChannel; import java.nio.channels.AsynchronousFileChannel;
import java.nio.channels.CompletionHandler; import java.nio.channels.CompletionHandler;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.nio.file.StandardOpenOption; import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.TreeMap; import java.util.TreeMap;
@ -45,91 +40,101 @@ import java.util.concurrent.LinkedBlockingQueue;
* @author Xylitol * @author Xylitol
*/ */
public class CalculateAverage_C5H12O5 { public class CalculateAverage_C5H12O5 {
private static final int BUFFER_CAPACITY = 1024 * 1024; private static final int BUFFER_CAPACITY = 1024 * 1024 * 10;
private static final int MAP_CAPACITY = 10000; private static final int MAP_CAPACITY = 10000;
private static final int QUEUE_CAPACITY = 2; private static final int PROCESSORS = Runtime.getRuntime().availableProcessors();
private static final BlockingQueue<byte[]> BYTES_QUEUE = new LinkedBlockingQueue<>(PROCESSORS);
private static long readPosition;
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
// Files.list(Paths.get("./src/test/resources/samples")) System.out.println(calc("./measurements.txt"));
// .filter(file -> file.toString().endsWith(".txt"))
// .forEach(file -> {
// try {
// String actual = calc(file);
// String expected = Files.readAllLines(Paths.get(file.toString().replace(".txt", ".out"))).get(0);
// System.out.println(file.getFileName() + ": " + expected.equals(actual));
// } catch (Exception e) {
// System.out.println(file.getFileName() + ": " + false);
// e.printStackTrace();
// }
// });
// long start = System.currentTimeMillis();
System.out.println(calc(Paths.get("./measurements.txt")));
// System.out.println("Time: " + (System.currentTimeMillis() - start) + "ms");
} }
/** /**
* Calculate the average. * Calculate the average.
*/ */
public static String calc(Path file) throws IOException, ExecutionException, InterruptedException { public static String calc(String path) throws IOException, ExecutionException, InterruptedException {
long[] positions = fragment(file, Runtime.getRuntime().availableProcessors()); readPosition = 0;
FutureTask<Map<MeasurementName, MeasurementData>>[] tasks = new FutureTask[positions.length]; Map<String, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
for (int i = 0; i < positions.length; i++) { // read and offer to queue
tasks[i] = new FutureTask<>(new Task(file, (i == 0 ? 0 : positions[i - 1] + 1), positions[i])); try (AsynchronousFileChannel channel = AsynchronousFileChannel.open(
Paths.get(path), Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor())) {
ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY);
channel.read(buffer, readPosition, buffer, new CompletionHandler<>() {
@Override
public void completed(Integer bytesRead, ByteBuffer buffer) {
try {
if (bytesRead > 0) {
for (int i = buffer.position() - 1; i >= 0; i--) {
if (buffer.get(i) == '\n') {
buffer.limit(i + 1);
break;
}
}
buffer.flip();
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);
readPosition += buffer.limit();
BYTES_QUEUE.put(bytes);
buffer.clear();
channel.read(buffer, readPosition, buffer, this);
}
else {
for (int i = 0; i < PROCESSORS; i++) {
BYTES_QUEUE.put(new byte[0]);
}
}
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
@Override
public void failed(Throwable exc, ByteBuffer buffer) {
// ignore
}
});
@SuppressWarnings("unchecked")
FutureTask<Map<MeasurementName, MeasurementData>>[] tasks = new FutureTask[PROCESSORS];
for (int i = 0; i < PROCESSORS; i++) {
tasks[i] = new FutureTask<>(new Task());
new Thread(tasks[i]).start(); new Thread(tasks[i]).start();
} }
Map<String, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
for (FutureTask<Map<MeasurementName, MeasurementData>> task : tasks) { for (FutureTask<Map<MeasurementName, MeasurementData>> task : tasks) {
task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge)); task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge));
} }
}
return new TreeMap<>(result).toString(); return new TreeMap<>(result).toString();
} }
/**
* Fragment the file into chunks.
*/
private static long[] fragment(Path filePath, int chunkNum) throws IOException {
long fileSize = Files.size(filePath);
long chunkSize = fileSize / chunkNum;
long[] positions = new long[chunkNum];
try (RandomAccessFile file = new RandomAccessFile(filePath.toFile(), "r")) {
long position = chunkSize;
for (int i = 0; i < chunkNum - 1; i++) {
if (position >= fileSize) {
break;
}
file.seek(position);
while (file.read() != '\n') {
position++;
}
positions[i] = position;
position += chunkSize;
}
}
positions[chunkNum - 1] = fileSize;
return Arrays.stream(positions).filter(value -> value != 0).toArray();
}
/** /**
* The measurement name. * The measurement name.
*/ */
private record MeasurementName(byte[] bytes) { private record MeasurementName(byte[] bytes, int length) {
@Override @Override
public boolean equals(Object other) { public boolean equals(Object name) {
if (!(other instanceof MeasurementName)) { MeasurementName other = (MeasurementName) name;
if (other.length != length) {
return false; return false;
} }
return Arrays.equals(bytes, ((MeasurementName) other).bytes); return Arrays.compare(bytes, 0, length, other.bytes, 0, length) == 0;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Arrays.hashCode(bytes); int result = 1;
for (int i = 0; i < length; i++) {
result = 31 * result + bytes[i];
}
return result;
} }
@Override @Override
public String toString() { public String toString() {
return new String(bytes, StandardCharsets.UTF_8); return new String(bytes, 0, length, StandardCharsets.UTF_8);
} }
} }
@ -168,100 +173,31 @@ public class CalculateAverage_C5H12O5 {
} }
/** /**
* The task to read and calculate. * The task to calculate.
*/ */
private static class Task implements Callable<Map<MeasurementName, MeasurementData>> { private static class Task implements Callable<Map<MeasurementName, MeasurementData>> {
private final Path file;
private long readPosition;
private long calcPosition;
private final long limitSize;
private final BlockingQueue<byte[]> bytesQueue = new LinkedBlockingQueue<>(QUEUE_CAPACITY);
public Task(Path file, long position, long limitSize) {
this.file = file;
this.readPosition = position;
this.calcPosition = position;
this.limitSize = limitSize;
}
@Override @Override
public Map<MeasurementName, MeasurementData> call() throws IOException { public Map<MeasurementName, MeasurementData> call() throws InterruptedException {
// read and offer to queue
AsynchronousFileChannel channel = AsynchronousFileChannel.open(
file, Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor());
ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY);
channel.read(buffer, readPosition, buffer, new CompletionHandler<>() {
@Override
public void completed(Integer bytesRead, ByteBuffer buffer) {
if (bytesRead > 0 && readPosition < limitSize) {
try {
buffer.flip();
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);
readPosition += bytesRead;
if (readPosition > limitSize) {
int diff = (int) (readPosition - limitSize);
byte[] newBytes = new byte[bytes.length - diff];
System.arraycopy(bytes, 0, newBytes, 0, newBytes.length);
bytesQueue.put(newBytes);
}
else {
bytesQueue.put(bytes);
buffer.clear();
channel.read(buffer, readPosition, buffer, this);
}
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
@Override
public void failed(Throwable exc, ByteBuffer buffer) {
// ignore
}
});
// poll from queue and calculate // poll from queue and calculate
Map<MeasurementName, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY); Map<MeasurementName, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
byte[] readBytes = null; for (byte[] bytes = BYTES_QUEUE.take(); true; bytes = BYTES_QUEUE.take()) {
byte[] remaining = null; if (bytes.length == 0) {
while (calcPosition < limitSize) {
readBytes = bytesQueue.poll();
if (readBytes != null) {
List<byte[]> lines = split(readBytes, (byte) '\n');
for (int i = 0; i < lines.size(); i++) {
byte[] lineBytes = lines.get(i);
if (i == 0 && remaining != null) {
byte[] newBytes = new byte[remaining.length + lineBytes.length];
System.arraycopy(remaining, 0, newBytes, 0, remaining.length);
System.arraycopy(lineBytes, 0, newBytes, remaining.length, lineBytes.length);
lineBytes = newBytes;
}
if (i == lines.size() - 1) {
remaining = lineBytes;
break; break;
} }
agg(result, lineBytes); int start = 0;
} for (int end = 0; end < bytes.length; end++) {
calcPosition += readBytes.length; if (bytes[end] == '\n') {
byte[] newBytes = new byte[end - start];
System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
int semicolon = newBytes.length - 4;
for (; semicolon >= 0; semicolon--) {
if (newBytes[semicolon] == ';') {
break;
} }
} }
if (remaining != null && remaining.length > 0) { MeasurementName station = new MeasurementName(newBytes, semicolon);
agg(result, remaining); int value = toInt(newBytes, semicolon + 1);
}
channel.close();
return result;
}
/**
* Aggregate the measurement data.
*/
private static void agg(Map<MeasurementName, MeasurementData> result, byte[] bytes) {
List<byte[]> parts = split(bytes, (byte) ';');
MeasurementName station = new MeasurementName(parts.getFirst());
int value = toInt(parts.getLast());
MeasurementData data = result.get(station); MeasurementData data = result.get(station);
if (data != null) { if (data != null) {
data.merge(value, value, value, 1); data.merge(value, value, value, 1);
@ -269,15 +205,21 @@ public class CalculateAverage_C5H12O5 {
else { else {
result.put(station, new MeasurementData(value)); result.put(station, new MeasurementData(value));
} }
start = end + 1;
}
}
}
return result;
} }
/** /**
* Convert the byte array to int. * Convert the byte array to int.
*/ */
private static int toInt(byte[] bytes) { private static int toInt(byte[] bytes, int start) {
boolean negative = false; boolean negative = false;
int result = 0; int result = 0;
for (byte b : bytes) { for (int i = start; i < bytes.length; i++) {
byte b = bytes[i];
if (b == '-') { if (b == '-') {
negative = true; negative = true;
continue; continue;
@ -288,27 +230,5 @@ public class CalculateAverage_C5H12O5 {
} }
return negative ? -result : result; return negative ? -result : result;
} }
/**
* Split the byte array by given byte.
*/
private static List<byte[]> split(byte[] bytes, byte separator) {
List<byte[]> result = new ArrayList<>();
int start = 0;
for (int end = 0; end < bytes.length; end++) {
if (bytes[end] == separator) {
byte[] newBytes = new byte[end - start];
System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
result.add(newBytes);
start = end + 1;
}
}
if (start <= bytes.length) {
byte[] newBytes = new byte[bytes.length - start];
System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
result.add(newBytes);
}
return result;
}
} }
} }