Processing byte array backwards (#504)
This commit is contained in:
parent
51f8ecfa43
commit
8353a1cb3d
@ -15,136 +15,386 @@
|
||||
*/
|
||||
package dev.morling.onebrc;
|
||||
|
||||
import sun.misc.Unsafe;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.RandomAccessFile;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.channels.AsynchronousFileChannel;
|
||||
import java.nio.channels.CompletionHandler;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.TreeMap;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.FutureTask;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.LinkedTransferQueue;
|
||||
import java.util.concurrent.TransferQueue;
|
||||
|
||||
/**
|
||||
* Calculates the average using AIO and multiple threads.
|
||||
* Results on Mac mini (Apple M2 with 8-core CPU / 8GB unified memory):
|
||||
* <pre>
|
||||
* using AIO and multiple threads:
|
||||
* 120.15s user 4.33s system 710% cpu 17.522 total
|
||||
*
|
||||
* reduce the number of memory copies:
|
||||
* 45.87s user 2.82s system 530% cpu 9.185 total
|
||||
*
|
||||
* processing byte array backwards and using bitwise operation to find specific byte (inspired by thomaswue):
|
||||
* 25.38s user 3.44s system 342% cpu 8.406 total
|
||||
* </pre>
|
||||
*
|
||||
* @author Xylitol
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public class CalculateAverage_C5H12O5 {
|
||||
private static final int BUFFER_CAPACITY = 1024 * 1024 * 10;
|
||||
private static final int MAP_CAPACITY = 10000;
|
||||
private static final int PROCESSORS = Runtime.getRuntime().availableProcessors();
|
||||
private static final BlockingQueue<byte[]> BYTES_QUEUE = new LinkedBlockingQueue<>(PROCESSORS);
|
||||
private static long readPosition;
|
||||
private static final int AVAILABLE_PROCESSOR_NUM = Runtime.getRuntime().availableProcessors();
|
||||
private static final int TRANSFER_QUEUE_CAPACITY = 1024 / 16 / AVAILABLE_PROCESSOR_NUM; // 1GB memory max
|
||||
private static final int BYTE_BUFFER_CAPACITY = 1024 * 1024 * 16; // 16MB one time
|
||||
private static final int EXPECTED_MAPPINGS_NUM = 10000;
|
||||
|
||||
/**
|
||||
* Fragment the file into chunks.
|
||||
*/
|
||||
private static long[] fragment(Path path) throws IOException {
|
||||
long size = Files.size(path);
|
||||
long chunk = size / AVAILABLE_PROCESSOR_NUM;
|
||||
List<Long> positions = new ArrayList<>();
|
||||
try (RandomAccessFile file = new RandomAccessFile(path.toFile(), "r")) {
|
||||
long position = chunk;
|
||||
for (int i = 0; i < AVAILABLE_PROCESSOR_NUM - 1; i++) {
|
||||
if (position >= size) {
|
||||
break;
|
||||
}
|
||||
file.seek(position);
|
||||
// move the position to the next newline byte
|
||||
while (file.read() != '\n') {
|
||||
position++;
|
||||
}
|
||||
positions.add(++position);
|
||||
position += chunk;
|
||||
}
|
||||
}
|
||||
if (positions.isEmpty() || positions.getLast() < size) {
|
||||
positions.add(size);
|
||||
}
|
||||
return positions.stream().mapToLong(Long::longValue).toArray();
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
System.out.println(calc("./measurements.txt"));
|
||||
// fragment the input file
|
||||
Path path = Path.of("./measurements.txt");
|
||||
long[] positions = fragment(path);
|
||||
|
||||
// start the calculation tasks
|
||||
FutureTask<Map<Station, MeasurementData>>[] tasks = new FutureTask[positions.length];
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
tasks[i] = new FutureTask<>(new Calculator(path, (i == 0 ? 0 : positions[i - 1]), positions[i]));
|
||||
new Thread(tasks[i]).start();
|
||||
}
|
||||
|
||||
// wait for the results
|
||||
Map<Station, MeasurementData> result = HashMap.newHashMap(EXPECTED_MAPPINGS_NUM);
|
||||
for (FutureTask<Map<Station, MeasurementData>> task : tasks) {
|
||||
task.get().forEach((k, v) -> result.merge(k, v, MeasurementData::merge));
|
||||
}
|
||||
|
||||
// sort and print the results
|
||||
TreeMap<String, MeasurementData> sorted = new TreeMap<>();
|
||||
for (Map.Entry<Station, MeasurementData> entry : result.entrySet()) {
|
||||
sorted.put(new String(entry.getKey().bytes, StandardCharsets.UTF_8), entry.getValue());
|
||||
}
|
||||
System.out.println(sorted);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the average.
|
||||
* The calculation task.
|
||||
*/
|
||||
public static String calc(String path) throws IOException, ExecutionException, InterruptedException {
|
||||
readPosition = 0;
|
||||
Map<String, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
|
||||
// read and offer to queue
|
||||
try (AsynchronousFileChannel channel = AsynchronousFileChannel.open(
|
||||
Paths.get(path), Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor())) {
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY);
|
||||
channel.read(buffer, readPosition, buffer, new CompletionHandler<>() {
|
||||
private static class Calculator implements Callable<Map<Station, MeasurementData>> {
|
||||
private final TransferQueue<byte[]> transfer = new LinkedTransferQueue<>();
|
||||
private final AsynchronousFileChannel asyncChannel;
|
||||
private final long limit;
|
||||
private long position;
|
||||
|
||||
public Calculator(Path file, long position, long limit) throws IOException {
|
||||
ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
|
||||
this.asyncChannel = AsynchronousFileChannel.open(file, Set.of(StandardOpenOption.READ), executor);
|
||||
this.position = position;
|
||||
this.limit = limit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void completed(Integer bytesRead, ByteBuffer buffer) {
|
||||
try {
|
||||
if (bytesRead > 0) {
|
||||
public Map<Station, MeasurementData> call() throws InterruptedException {
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect(BYTE_BUFFER_CAPACITY);
|
||||
asyncChannel.read(buffer, position, buffer, new CompletionHandler<>() {
|
||||
@Override
|
||||
public void completed(Integer readSize, ByteBuffer buffer) {
|
||||
if (position + readSize >= limit) {
|
||||
buffer.limit(readSize - (int) (position + readSize - limit));
|
||||
}
|
||||
else {
|
||||
for (int i = buffer.position() - 1; i >= 0; i--) {
|
||||
if (buffer.get(i) == '\n') {
|
||||
// truncate the buffer to the last newline byte
|
||||
buffer.limit(i + 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
buffer.flip();
|
||||
byte[] bytes = new byte[buffer.remaining()];
|
||||
buffer.get(bytes);
|
||||
readPosition += buffer.limit();
|
||||
BYTES_QUEUE.put(bytes);
|
||||
byte[] bytes = new byte[buffer.limit() + 1];
|
||||
// add a newline byte at the beginning
|
||||
bytes[0] = '\n';
|
||||
buffer.get(bytes, 1, buffer.limit());
|
||||
transfer(bytes);
|
||||
if ((position += buffer.limit()) < limit) {
|
||||
buffer.clear();
|
||||
channel.read(buffer, readPosition, buffer, this);
|
||||
asyncChannel.read(buffer, position, buffer, this);
|
||||
}
|
||||
else {
|
||||
for (int i = 0; i < PROCESSORS; i++) {
|
||||
BYTES_QUEUE.put(new byte[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
// stop signal
|
||||
transfer(new byte[0]);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void failed(Throwable exc, ByteBuffer buffer) {
|
||||
// ignore
|
||||
transfer(new byte[0]);
|
||||
}
|
||||
});
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
FutureTask<Map<MeasurementName, MeasurementData>>[] tasks = new FutureTask[PROCESSORS];
|
||||
for (int i = 0; i < PROCESSORS; i++) {
|
||||
tasks[i] = new FutureTask<>(new Task());
|
||||
new Thread(tasks[i]).start();
|
||||
}
|
||||
for (FutureTask<Map<MeasurementName, MeasurementData>> task : tasks) {
|
||||
task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge));
|
||||
}
|
||||
}
|
||||
return new TreeMap<>(result).toString();
|
||||
return process();
|
||||
}
|
||||
|
||||
/**
|
||||
* The measurement name.
|
||||
* Transfer or put the bytes to the queue.
|
||||
*/
|
||||
private record MeasurementName(byte[] bytes, int length) {
|
||||
private void transfer(byte[] bytes) {
|
||||
try {
|
||||
if (transfer.size() >= TRANSFER_QUEUE_CAPACITY) {
|
||||
transfer.transfer(bytes);
|
||||
}
|
||||
else {
|
||||
transfer.put(bytes);
|
||||
}
|
||||
}
|
||||
catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Take and process the bytes from the queue.
|
||||
*/
|
||||
private Map<Station, MeasurementData> process() throws InterruptedException {
|
||||
Map<Station, MeasurementData> result = HashMap.newHashMap(EXPECTED_MAPPINGS_NUM);
|
||||
for (byte[] bytes = transfer.take(); bytes.length > 0; bytes = transfer.take()) {
|
||||
Station station = new Station(bytes);
|
||||
// read the bytes backwards
|
||||
for (int position = bytes.length - 2; position >= 1; position--) {
|
||||
|
||||
// calculate the temperature value
|
||||
int temperature = bytes[position] - '0' + (bytes[position -= 2] - '0') * 10;
|
||||
byte unknownByte = bytes[--position];
|
||||
int semicolon = switch (unknownByte) {
|
||||
case ';' -> position;
|
||||
case '-' -> {
|
||||
temperature = -temperature;
|
||||
yield --position;
|
||||
}
|
||||
default -> {
|
||||
temperature += (unknownByte - '0') * 100;
|
||||
if (bytes[--position] == '-') {
|
||||
temperature = -temperature;
|
||||
--position;
|
||||
}
|
||||
yield position;
|
||||
}
|
||||
};
|
||||
|
||||
// calculate the station name hash
|
||||
int hash = 1;
|
||||
while (true) {
|
||||
long temp = LineFinder.previousLong(bytes, position);
|
||||
int distance = LineFinder.NATIVE.fromRight(temp);
|
||||
if (distance == 0) {
|
||||
// current byte is '\n'
|
||||
break;
|
||||
}
|
||||
position -= distance;
|
||||
if (distance == 8) {
|
||||
// can't find '\n' in previous 8 bytes
|
||||
hash = 31 * hash + (int) (temp ^ (temp >>> 32));
|
||||
continue;
|
||||
}
|
||||
// clear the redundant bytes
|
||||
temp = LineFinder.NATIVE.clearLeft(temp, distance);
|
||||
hash = 31 * hash + (int) (temp ^ (temp >>> 32));
|
||||
}
|
||||
|
||||
// merge data to the result map
|
||||
MeasurementData data = result.get(station.slice(hash, position + 1, semicolon));
|
||||
if (data == null) {
|
||||
result.put(station.copy(), new MeasurementData(temperature));
|
||||
} else {
|
||||
data.merge(temperature);
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* To find the nearest newline byte position in a long.
|
||||
*/
|
||||
private interface LineFinder {
|
||||
// choose the implementation according to the native byte order
|
||||
LineFinder NATIVE = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? LELineFinder.INST : BELineFinder.INST;
|
||||
|
||||
Unsafe UNSAFE = initUnsafe();
|
||||
int BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class);
|
||||
int LONG_BYTES = Long.SIZE / Byte.SIZE;
|
||||
|
||||
static Unsafe initUnsafe() {
|
||||
try {
|
||||
Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
|
||||
theUnsafe.setAccessible(true);
|
||||
return (Unsafe) theUnsafe.get(Unsafe.class);
|
||||
}
|
||||
catch (NoSuchFieldException | IllegalAccessException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
static long previousLong(byte[] bytes, long offset) {
|
||||
return UNSAFE.getLong(bytes, BYTE_ARRAY_BASE_OFFSET + offset + 1 - LONG_BYTES);
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark the highest bit of newline byte (0x0A) to 1.
|
||||
*/
|
||||
static long markHighestBit(long longBytes) {
|
||||
long temp = longBytes ^ 0x0A0A0A0A0A0A0A0AL;
|
||||
return (temp - 0x0101010101010101L) & ~temp & 0x8080808080808080L;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the nearest newline byte position from right to left.
|
||||
*/
|
||||
int fromRight(long longBytes);
|
||||
|
||||
/**
|
||||
* Clear the left bytes out of the range.
|
||||
*/
|
||||
long clearLeft(long longBytes, int keepNum);
|
||||
|
||||
enum LELineFinder implements LineFinder {
|
||||
INST;
|
||||
|
||||
private static final long[] MASKS = new long[8];
|
||||
|
||||
static {
|
||||
for (int i = 1; i <= 7; i++) {
|
||||
MASKS[i] = 0xFFFFFFFFFFFFFFFFL << ((8 - i) << 3);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object name) {
|
||||
MeasurementName other = (MeasurementName) name;
|
||||
if (other.length != length) {
|
||||
return false;
|
||||
public int fromRight(long longBytes) {
|
||||
return Long.numberOfLeadingZeros(markHighestBit(longBytes)) >>> 3;
|
||||
}
|
||||
return Arrays.compare(bytes, 0, length, other.bytes, 0, length) == 0;
|
||||
|
||||
@Override
|
||||
public long clearLeft(long longBytes, int keepNum) {
|
||||
return longBytes & MASKS[keepNum];
|
||||
}
|
||||
}
|
||||
|
||||
enum BELineFinder implements LineFinder {
|
||||
INST;
|
||||
|
||||
private static final long[] MASKS = new long[8];
|
||||
|
||||
static {
|
||||
for (int i = 1; i <= 7; i++) {
|
||||
MASKS[i] = 0xFFFFFFFFFFFFFFFFL >>> ((8 - i) << 3);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int fromRight(long longBytes) {
|
||||
return Long.numberOfTrailingZeros(markHighestBit(longBytes)) >>> 3;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long clearLeft(long longBytes, int keepNum) {
|
||||
return longBytes & MASKS[keepNum];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The station name wrapper ( bytes[from, to) ).
|
||||
*/
|
||||
private static class Station {
|
||||
private final byte[] bytes;
|
||||
private int from;
|
||||
private int to;
|
||||
private int hash;
|
||||
|
||||
public Station(byte[] bytes) {
|
||||
this(bytes, 0, 0, 0);
|
||||
}
|
||||
|
||||
public Station(byte[] bytes, int hash, int from, int to) {
|
||||
this.bytes = bytes;
|
||||
this.slice(hash, from, to);
|
||||
}
|
||||
|
||||
public Station slice(int hash, int from, int to) {
|
||||
this.hash = hash;
|
||||
this.from = from;
|
||||
this.to = to;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Station copy() {
|
||||
int length = to - from;
|
||||
byte[] newBytes = new byte[length];
|
||||
System.arraycopy(bytes, from, newBytes, 0, length);
|
||||
return new Station(newBytes, hash, 0, length);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object station) {
|
||||
Station other = (Station) station;
|
||||
return Arrays.equals(bytes, from, to, other.bytes, other.from, other.to);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = 1;
|
||||
for (int i = 0; i < length; i++) {
|
||||
result = 31 * result + bytes[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return new String(bytes, 0, length, StandardCharsets.UTF_8);
|
||||
return hash;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The measurement data.
|
||||
* The measurement data wrapper ( temperature * 10 ).
|
||||
*/
|
||||
private static class MeasurementData {
|
||||
private int min;
|
||||
private int max;
|
||||
private int sum;
|
||||
private long sum;
|
||||
private int count;
|
||||
|
||||
public MeasurementData(int value) {
|
||||
@ -154,11 +404,15 @@ public class CalculateAverage_C5H12O5 {
|
||||
this.count = 1;
|
||||
}
|
||||
|
||||
public MeasurementData merge(MeasurementData data) {
|
||||
return merge(data.min, data.max, data.sum, data.count);
|
||||
public MeasurementData merge(int value) {
|
||||
return merge(value, value, value, 1);
|
||||
}
|
||||
|
||||
public MeasurementData merge(int min, int max, int sum, int count) {
|
||||
public MeasurementData merge(MeasurementData other) {
|
||||
return merge(other.min, other.max, other.sum, other.count);
|
||||
}
|
||||
|
||||
public MeasurementData merge(int min, int max, long sum, int count) {
|
||||
this.min = Math.min(this.min, min);
|
||||
this.max = Math.max(this.max, max);
|
||||
this.sum += sum;
|
||||
@ -168,67 +422,7 @@ public class CalculateAverage_C5H12O5 {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return (min / 10.0) + "/" + (Math.round((double) sum / count) / 10.0) + "/" + (max / 10.0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The task to calculate.
|
||||
*/
|
||||
private static class Task implements Callable<Map<MeasurementName, MeasurementData>> {
|
||||
|
||||
@Override
|
||||
public Map<MeasurementName, MeasurementData> call() throws InterruptedException {
|
||||
// poll from queue and calculate
|
||||
Map<MeasurementName, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
|
||||
for (byte[] bytes = BYTES_QUEUE.take(); true; bytes = BYTES_QUEUE.take()) {
|
||||
if (bytes.length == 0) {
|
||||
break;
|
||||
}
|
||||
int start = 0;
|
||||
for (int end = 0; end < bytes.length; end++) {
|
||||
if (bytes[end] == '\n') {
|
||||
byte[] newBytes = new byte[end - start];
|
||||
System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
|
||||
int semicolon = newBytes.length - 4;
|
||||
for (; semicolon >= 0; semicolon--) {
|
||||
if (newBytes[semicolon] == ';') {
|
||||
break;
|
||||
}
|
||||
}
|
||||
MeasurementName station = new MeasurementName(newBytes, semicolon);
|
||||
int value = toInt(newBytes, semicolon + 1);
|
||||
MeasurementData data = result.get(station);
|
||||
if (data != null) {
|
||||
data.merge(value, value, value, 1);
|
||||
}
|
||||
else {
|
||||
result.put(station, new MeasurementData(value));
|
||||
}
|
||||
start = end + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert the byte array to int.
|
||||
*/
|
||||
private static int toInt(byte[] bytes, int start) {
|
||||
boolean negative = false;
|
||||
int result = 0;
|
||||
for (int i = start; i < bytes.length; i++) {
|
||||
byte b = bytes[i];
|
||||
if (b == '-') {
|
||||
negative = true;
|
||||
continue;
|
||||
}
|
||||
if (b != '.') {
|
||||
result = result * 10 + (b - '0');
|
||||
}
|
||||
}
|
||||
return negative ? -result : result;
|
||||
return STR."\{min / 10.0}/\{Math.round((double) sum / count) / 10.0}/\{max / 10.0}";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user