Improved my implementation
This commit is contained in:
parent
09e0311e09
commit
9e5ec51315
@ -16,19 +16,14 @@
|
|||||||
package dev.morling.onebrc;
|
package dev.morling.onebrc;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.RandomAccessFile;
|
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.channels.AsynchronousFileChannel;
|
import java.nio.channels.AsynchronousFileChannel;
|
||||||
import java.nio.channels.CompletionHandler;
|
import java.nio.channels.CompletionHandler;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.nio.file.Files;
|
|
||||||
import java.nio.file.Path;
|
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.nio.file.StandardOpenOption;
|
import java.nio.file.StandardOpenOption;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
@ -45,91 +40,101 @@ import java.util.concurrent.LinkedBlockingQueue;
|
|||||||
* @author Xylitol
|
* @author Xylitol
|
||||||
*/
|
*/
|
||||||
public class CalculateAverage_C5H12O5 {
|
public class CalculateAverage_C5H12O5 {
|
||||||
private static final int BUFFER_CAPACITY = 1024 * 1024;
|
private static final int BUFFER_CAPACITY = 1024 * 1024 * 10;
|
||||||
private static final int MAP_CAPACITY = 10000;
|
private static final int MAP_CAPACITY = 10000;
|
||||||
private static final int QUEUE_CAPACITY = 2;
|
private static final int PROCESSORS = Runtime.getRuntime().availableProcessors();
|
||||||
|
private static final BlockingQueue<byte[]> BYTES_QUEUE = new LinkedBlockingQueue<>(PROCESSORS);
|
||||||
|
private static long readPosition;
|
||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
// Files.list(Paths.get("./src/test/resources/samples"))
|
System.out.println(calc("./measurements.txt"));
|
||||||
// .filter(file -> file.toString().endsWith(".txt"))
|
|
||||||
// .forEach(file -> {
|
|
||||||
// try {
|
|
||||||
// String actual = calc(file);
|
|
||||||
// String expected = Files.readAllLines(Paths.get(file.toString().replace(".txt", ".out"))).get(0);
|
|
||||||
// System.out.println(file.getFileName() + ": " + expected.equals(actual));
|
|
||||||
// } catch (Exception e) {
|
|
||||||
// System.out.println(file.getFileName() + ": " + false);
|
|
||||||
// e.printStackTrace();
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
// long start = System.currentTimeMillis();
|
|
||||||
System.out.println(calc(Paths.get("./measurements.txt")));
|
|
||||||
// System.out.println("Time: " + (System.currentTimeMillis() - start) + "ms");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate the average.
|
* Calculate the average.
|
||||||
*/
|
*/
|
||||||
public static String calc(Path file) throws IOException, ExecutionException, InterruptedException {
|
public static String calc(String path) throws IOException, ExecutionException, InterruptedException {
|
||||||
long[] positions = fragment(file, Runtime.getRuntime().availableProcessors());
|
readPosition = 0;
|
||||||
FutureTask<Map<MeasurementName, MeasurementData>>[] tasks = new FutureTask[positions.length];
|
Map<String, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
|
||||||
for (int i = 0; i < positions.length; i++) {
|
// read and offer to queue
|
||||||
tasks[i] = new FutureTask<>(new Task(file, (i == 0 ? 0 : positions[i - 1] + 1), positions[i]));
|
try (AsynchronousFileChannel channel = AsynchronousFileChannel.open(
|
||||||
|
Paths.get(path), Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor())) {
|
||||||
|
ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY);
|
||||||
|
channel.read(buffer, readPosition, buffer, new CompletionHandler<>() {
|
||||||
|
@Override
|
||||||
|
public void completed(Integer bytesRead, ByteBuffer buffer) {
|
||||||
|
try {
|
||||||
|
if (bytesRead > 0) {
|
||||||
|
for (int i = buffer.position() - 1; i >= 0; i--) {
|
||||||
|
if (buffer.get(i) == '\n') {
|
||||||
|
buffer.limit(i + 1);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buffer.flip();
|
||||||
|
byte[] bytes = new byte[buffer.remaining()];
|
||||||
|
buffer.get(bytes);
|
||||||
|
readPosition += buffer.limit();
|
||||||
|
BYTES_QUEUE.put(bytes);
|
||||||
|
buffer.clear();
|
||||||
|
channel.read(buffer, readPosition, buffer, this);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (int i = 0; i < PROCESSORS; i++) {
|
||||||
|
BYTES_QUEUE.put(new byte[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (InterruptedException e) {
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void failed(Throwable exc, ByteBuffer buffer) {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
FutureTask<Map<MeasurementName, MeasurementData>>[] tasks = new FutureTask[PROCESSORS];
|
||||||
|
for (int i = 0; i < PROCESSORS; i++) {
|
||||||
|
tasks[i] = new FutureTask<>(new Task());
|
||||||
new Thread(tasks[i]).start();
|
new Thread(tasks[i]).start();
|
||||||
}
|
}
|
||||||
Map<String, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
|
|
||||||
for (FutureTask<Map<MeasurementName, MeasurementData>> task : tasks) {
|
for (FutureTask<Map<MeasurementName, MeasurementData>> task : tasks) {
|
||||||
task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge));
|
task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return new TreeMap<>(result).toString();
|
return new TreeMap<>(result).toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Fragment the file into chunks.
|
|
||||||
*/
|
|
||||||
private static long[] fragment(Path filePath, int chunkNum) throws IOException {
|
|
||||||
long fileSize = Files.size(filePath);
|
|
||||||
long chunkSize = fileSize / chunkNum;
|
|
||||||
long[] positions = new long[chunkNum];
|
|
||||||
try (RandomAccessFile file = new RandomAccessFile(filePath.toFile(), "r")) {
|
|
||||||
long position = chunkSize;
|
|
||||||
for (int i = 0; i < chunkNum - 1; i++) {
|
|
||||||
if (position >= fileSize) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
file.seek(position);
|
|
||||||
while (file.read() != '\n') {
|
|
||||||
position++;
|
|
||||||
}
|
|
||||||
positions[i] = position;
|
|
||||||
position += chunkSize;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
positions[chunkNum - 1] = fileSize;
|
|
||||||
return Arrays.stream(positions).filter(value -> value != 0).toArray();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The measurement name.
|
* The measurement name.
|
||||||
*/
|
*/
|
||||||
private record MeasurementName(byte[] bytes) {
|
private record MeasurementName(byte[] bytes, int length) {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object other) {
|
public boolean equals(Object name) {
|
||||||
if (!(other instanceof MeasurementName)) {
|
MeasurementName other = (MeasurementName) name;
|
||||||
|
if (other.length != length) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return Arrays.equals(bytes, ((MeasurementName) other).bytes);
|
return Arrays.compare(bytes, 0, length, other.bytes, 0, length) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Arrays.hashCode(bytes);
|
int result = 1;
|
||||||
|
for (int i = 0; i < length; i++) {
|
||||||
|
result = 31 * result + bytes[i];
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return new String(bytes, StandardCharsets.UTF_8);
|
return new String(bytes, 0, length, StandardCharsets.UTF_8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,100 +173,31 @@ public class CalculateAverage_C5H12O5 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The task to read and calculate.
|
* The task to calculate.
|
||||||
*/
|
*/
|
||||||
private static class Task implements Callable<Map<MeasurementName, MeasurementData>> {
|
private static class Task implements Callable<Map<MeasurementName, MeasurementData>> {
|
||||||
private final Path file;
|
|
||||||
private long readPosition;
|
|
||||||
private long calcPosition;
|
|
||||||
private final long limitSize;
|
|
||||||
private final BlockingQueue<byte[]> bytesQueue = new LinkedBlockingQueue<>(QUEUE_CAPACITY);
|
|
||||||
|
|
||||||
public Task(Path file, long position, long limitSize) {
|
|
||||||
this.file = file;
|
|
||||||
this.readPosition = position;
|
|
||||||
this.calcPosition = position;
|
|
||||||
this.limitSize = limitSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MeasurementName, MeasurementData> call() throws IOException {
|
public Map<MeasurementName, MeasurementData> call() throws InterruptedException {
|
||||||
// read and offer to queue
|
|
||||||
AsynchronousFileChannel channel = AsynchronousFileChannel.open(
|
|
||||||
file, Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor());
|
|
||||||
ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY);
|
|
||||||
channel.read(buffer, readPosition, buffer, new CompletionHandler<>() {
|
|
||||||
@Override
|
|
||||||
public void completed(Integer bytesRead, ByteBuffer buffer) {
|
|
||||||
if (bytesRead > 0 && readPosition < limitSize) {
|
|
||||||
try {
|
|
||||||
buffer.flip();
|
|
||||||
byte[] bytes = new byte[buffer.remaining()];
|
|
||||||
buffer.get(bytes);
|
|
||||||
readPosition += bytesRead;
|
|
||||||
if (readPosition > limitSize) {
|
|
||||||
int diff = (int) (readPosition - limitSize);
|
|
||||||
byte[] newBytes = new byte[bytes.length - diff];
|
|
||||||
System.arraycopy(bytes, 0, newBytes, 0, newBytes.length);
|
|
||||||
bytesQueue.put(newBytes);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
bytesQueue.put(bytes);
|
|
||||||
buffer.clear();
|
|
||||||
channel.read(buffer, readPosition, buffer, this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (InterruptedException e) {
|
|
||||||
Thread.currentThread().interrupt();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void failed(Throwable exc, ByteBuffer buffer) {
|
|
||||||
// ignore
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// poll from queue and calculate
|
// poll from queue and calculate
|
||||||
Map<MeasurementName, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
|
Map<MeasurementName, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
|
||||||
byte[] readBytes = null;
|
for (byte[] bytes = BYTES_QUEUE.take(); true; bytes = BYTES_QUEUE.take()) {
|
||||||
byte[] remaining = null;
|
if (bytes.length == 0) {
|
||||||
while (calcPosition < limitSize) {
|
|
||||||
readBytes = bytesQueue.poll();
|
|
||||||
if (readBytes != null) {
|
|
||||||
List<byte[]> lines = split(readBytes, (byte) '\n');
|
|
||||||
for (int i = 0; i < lines.size(); i++) {
|
|
||||||
byte[] lineBytes = lines.get(i);
|
|
||||||
if (i == 0 && remaining != null) {
|
|
||||||
byte[] newBytes = new byte[remaining.length + lineBytes.length];
|
|
||||||
System.arraycopy(remaining, 0, newBytes, 0, remaining.length);
|
|
||||||
System.arraycopy(lineBytes, 0, newBytes, remaining.length, lineBytes.length);
|
|
||||||
lineBytes = newBytes;
|
|
||||||
}
|
|
||||||
if (i == lines.size() - 1) {
|
|
||||||
remaining = lineBytes;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
agg(result, lineBytes);
|
int start = 0;
|
||||||
}
|
for (int end = 0; end < bytes.length; end++) {
|
||||||
calcPosition += readBytes.length;
|
if (bytes[end] == '\n') {
|
||||||
|
byte[] newBytes = new byte[end - start];
|
||||||
|
System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
|
||||||
|
int semicolon = newBytes.length - 4;
|
||||||
|
for (; semicolon >= 0; semicolon--) {
|
||||||
|
if (newBytes[semicolon] == ';') {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (remaining != null && remaining.length > 0) {
|
MeasurementName station = new MeasurementName(newBytes, semicolon);
|
||||||
agg(result, remaining);
|
int value = toInt(newBytes, semicolon + 1);
|
||||||
}
|
|
||||||
channel.close();
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Aggregate the measurement data.
|
|
||||||
*/
|
|
||||||
private static void agg(Map<MeasurementName, MeasurementData> result, byte[] bytes) {
|
|
||||||
List<byte[]> parts = split(bytes, (byte) ';');
|
|
||||||
MeasurementName station = new MeasurementName(parts.getFirst());
|
|
||||||
int value = toInt(parts.getLast());
|
|
||||||
MeasurementData data = result.get(station);
|
MeasurementData data = result.get(station);
|
||||||
if (data != null) {
|
if (data != null) {
|
||||||
data.merge(value, value, value, 1);
|
data.merge(value, value, value, 1);
|
||||||
@ -269,15 +205,21 @@ public class CalculateAverage_C5H12O5 {
|
|||||||
else {
|
else {
|
||||||
result.put(station, new MeasurementData(value));
|
result.put(station, new MeasurementData(value));
|
||||||
}
|
}
|
||||||
|
start = end + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert the byte array to int.
|
* Convert the byte array to int.
|
||||||
*/
|
*/
|
||||||
private static int toInt(byte[] bytes) {
|
private static int toInt(byte[] bytes, int start) {
|
||||||
boolean negative = false;
|
boolean negative = false;
|
||||||
int result = 0;
|
int result = 0;
|
||||||
for (byte b : bytes) {
|
for (int i = start; i < bytes.length; i++) {
|
||||||
|
byte b = bytes[i];
|
||||||
if (b == '-') {
|
if (b == '-') {
|
||||||
negative = true;
|
negative = true;
|
||||||
continue;
|
continue;
|
||||||
@ -288,27 +230,5 @@ public class CalculateAverage_C5H12O5 {
|
|||||||
}
|
}
|
||||||
return negative ? -result : result;
|
return negative ? -result : result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Split the byte array by given byte.
|
|
||||||
*/
|
|
||||||
private static List<byte[]> split(byte[] bytes, byte separator) {
|
|
||||||
List<byte[]> result = new ArrayList<>();
|
|
||||||
int start = 0;
|
|
||||||
for (int end = 0; end < bytes.length; end++) {
|
|
||||||
if (bytes[end] == separator) {
|
|
||||||
byte[] newBytes = new byte[end - start];
|
|
||||||
System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
|
|
||||||
result.add(newBytes);
|
|
||||||
start = end + 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (start <= bytes.length) {
|
|
||||||
byte[] newBytes = new byte[bytes.length - start];
|
|
||||||
System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
|
|
||||||
result.add(newBytes);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user