Processing byte array backwards (#504)

This commit is contained in:
Xylitol 2024-01-20 21:04:19 +08:00 committed by GitHub
parent 51f8ecfa43
commit 8353a1cb3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,136 +15,386 @@
*/
package dev.morling.onebrc;
import sun.misc.Unsafe;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.AsynchronousFileChannel;
import java.nio.channels.CompletionHandler;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.TransferQueue;
/**
* Calculates the average using AIO and multiple threads.
* Results on Mac mini (Apple M2 with 8-core CPU / 8GB unified memory):
* <pre>
* using AIO and multiple threads:
* 120.15s user 4.33s system 710% cpu 17.522 total
*
* reduce the number of memory copies:
* 45.87s user 2.82s system 530% cpu 9.185 total
*
* processing byte array backwards and using bitwise operation to find specific byte (inspired by thomaswue):
* 25.38s user 3.44s system 342% cpu 8.406 total
* </pre>
*
* @author Xylitol
*/
@SuppressWarnings("unchecked")
public class CalculateAverage_C5H12O5 {
private static final int BUFFER_CAPACITY = 1024 * 1024 * 10;
private static final int MAP_CAPACITY = 10000;
private static final int PROCESSORS = Runtime.getRuntime().availableProcessors();
private static final BlockingQueue<byte[]> BYTES_QUEUE = new LinkedBlockingQueue<>(PROCESSORS);
private static long readPosition;
private static final int AVAILABLE_PROCESSOR_NUM = Runtime.getRuntime().availableProcessors();
private static final int TRANSFER_QUEUE_CAPACITY = 1024 / 16 / AVAILABLE_PROCESSOR_NUM; // 1GB memory max
private static final int BYTE_BUFFER_CAPACITY = 1024 * 1024 * 16; // 16MB one time
private static final int EXPECTED_MAPPINGS_NUM = 10000;
/**
 * Splits the file into contiguous chunks, one per available processor at most, so that every
 * chunk boundary sits immediately after a newline byte.
 *
 * @param path the file to split
 * @return the exclusive end offset of each chunk in ascending order; chunk {@code i} starts at
 *         the previous element (or 0 for the first chunk) and the last element is the file size
 * @throws IOException if the file cannot be sized, opened or read
 */
private static long[] fragment(Path path) throws IOException {
    long size = Files.size(path);
    long chunk = size / AVAILABLE_PROCESSOR_NUM;
    List<Long> positions = new ArrayList<>();
    try (RandomAccessFile file = new RandomAccessFile(path.toFile(), "r")) {
        long position = chunk;
        for (int i = 0; i < AVAILABLE_PROCESSOR_NUM - 1; i++) {
            if (position >= size) {
                break;
            }
            file.seek(position);
            // move the position to the next newline byte
            int current;
            while ((current = file.read()) != '\n' && current != -1) {
                position++;
            }
            if (current == -1) {
                // read() reports end of file with -1 and would do so forever; no newline is left,
                // so the remainder of the file is handed to the final chunk added below
                break;
            }
            // the boundary is the offset just past the newline
            positions.add(++position);
            position += chunk;
        }
    }
    // make sure the last chunk always ends at the end of the file (also covers an empty file)
    if (positions.isEmpty() || positions.getLast() < size) {
        positions.add(size);
    }
    return positions.stream().mapToLong(Long::longValue).toArray();
}
/**
 * Entry point: fragments {@code ./measurements.txt}, runs one {@code Calculator} per fragment on
 * its own thread, merges the per-fragment maps and prints the stations ordered by name.
 *
 * @param args not read
 * @throws Exception if fragmenting the file, creating a calculator or waiting for a task fails
 */
public static void main(String[] args) throws Exception {
// NOTE(review): calc(...) is only declared among the lines of the previous revision; this
// statement looks like diff residue interleaved with the new body below — confirm and drop it.
System.out.println(calc("./measurements.txt"));
// fragment the input file
Path path = Path.of("./measurements.txt");
long[] positions = fragment(path);
// start the calculation tasks
// (generic array creation is unchecked; the class carries @SuppressWarnings("unchecked"))
FutureTask<Map<Station, MeasurementData>>[] tasks = new FutureTask[positions.length];
for (int i = 0; i < positions.length; i++) {
// fragment i covers [previous end offset or 0, positions[i])
tasks[i] = new FutureTask<>(new Calculator(path, (i == 0 ? 0 : positions[i - 1]), positions[i]));
new Thread(tasks[i]).start();
}
// wait for the results
Map<Station, MeasurementData> result = HashMap.newHashMap(EXPECTED_MAPPINGS_NUM);
for (FutureTask<Map<Station, MeasurementData>> task : tasks) {
// get() blocks until the task has finished; equal stations are combined via MeasurementData::merge
task.get().forEach((k, v) -> result.merge(k, v, MeasurementData::merge));
}
// sort and print the results
// (station bytes are decoded as UTF-8 and the TreeMap orders the resulting strings)
TreeMap<String, MeasurementData> sorted = new TreeMap<>();
for (Map.Entry<Station, MeasurementData> entry : result.entrySet()) {
sorted.put(new String(entry.getKey().bytes, StandardCharsets.UTF_8), entry.getValue());
}
System.out.println(sorted);
}
/**
* Calculate the average.
* The calculation task.
*/
public static String calc(String path) throws IOException, ExecutionException, InterruptedException {
readPosition = 0;
Map<String, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
// read and offer to queue
try (AsynchronousFileChannel channel = AsynchronousFileChannel.open(
Paths.get(path), Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor())) {
ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY);
channel.read(buffer, readPosition, buffer, new CompletionHandler<>() {
private static class Calculator implements Callable<Map<Station, MeasurementData>> {
private final TransferQueue<byte[]> transfer = new LinkedTransferQueue<>();
private final AsynchronousFileChannel asyncChannel;
private final long limit;
private long position;
/**
 * Creates a calculation task for one byte range of the file.
 *
 * @param file     the measurements file, opened here for asynchronous reads
 * @param position the offset at which this task starts reading
 * @param limit    the offset at which this task stops reading
 * @throws IOException if the asynchronous channel cannot be opened
 */
public Calculator(Path file, long position, long limit) throws IOException {
    this.position = position;
    this.limit = limit;
    // completion handlers of this channel run on a virtual-thread-per-task executor
    this.asyncChannel = AsynchronousFileChannel.open(
            file,
            Set.of(StandardOpenOption.READ),
            Executors.newVirtualThreadPerTaskExecutor());
}
@Override
public Map<Station, MeasurementData> call() throws InterruptedException {
ByteBuffer buffer = ByteBuffer.allocateDirect(BYTE_BUFFER_CAPACITY);
asyncChannel.read(buffer, position, buffer, new CompletionHandler<>() {
@Override
public void completed(Integer bytesRead, ByteBuffer buffer) {
try {
if (bytesRead > 0) {
for (int i = buffer.position() - 1; i >= 0; i--) {
if (buffer.get(i) == '\n') {
buffer.limit(i + 1);
break;
}
}
buffer.flip();
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);
readPosition += buffer.limit();
BYTES_QUEUE.put(bytes);
buffer.clear();
channel.read(buffer, readPosition, buffer, this);
}
else {
for (int i = 0; i < PROCESSORS; i++) {
BYTES_QUEUE.put(new byte[0]);
public void completed(Integer readSize, ByteBuffer buffer) {
if (position + readSize >= limit) {
buffer.limit(readSize - (int) (position + readSize - limit));
}
else {
for (int i = buffer.position() - 1; i >= 0; i--) {
if (buffer.get(i) == '\n') {
// truncate the buffer to the last newline byte
buffer.limit(i + 1);
break;
}
}
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
buffer.flip();
byte[] bytes = new byte[buffer.limit() + 1];
// add a newline byte at the beginning
bytes[0] = '\n';
buffer.get(bytes, 1, buffer.limit());
transfer(bytes);
if ((position += buffer.limit()) < limit) {
buffer.clear();
asyncChannel.read(buffer, position, buffer, this);
}
else {
// stop signal
transfer(new byte[0]);
}
}
@Override
public void failed(Throwable exc, ByteBuffer buffer) {
// ignore
transfer(new byte[0]);
}
});
return process();
}
@SuppressWarnings("unchecked")
FutureTask<Map<MeasurementName, MeasurementData>>[] tasks = new FutureTask[PROCESSORS];
for (int i = 0; i < PROCESSORS; i++) {
tasks[i] = new FutureTask<>(new Task());
new Thread(tasks[i]).start();
/**
* Transfer or put the bytes to the queue.
*/
private void transfer(byte[] bytes) {
try {
if (transfer.size() >= TRANSFER_QUEUE_CAPACITY) {
transfer.transfer(bytes);
}
else {
transfer.put(bytes);
}
}
for (FutureTask<Map<MeasurementName, MeasurementData>> task : tasks) {
task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge));
catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
return new TreeMap<>(result).toString();
/**
 * Takes byte chunks from the transfer queue and aggregates them per station until the
 * zero-length stop signal is taken.
 * <p>
 * Every chunk is parsed from its end towards its start. The read-completion handler in
 * {@code call()} prepends a {@code '\n'} to each chunk, which is what terminates the backwards
 * search for the start of the first line. Temperatures are accumulated as integers scaled by 10.
 *
 * @return the aggregated measurements of this task, keyed by station
 * @throws InterruptedException if the thread is interrupted while waiting on the queue
 */
private Map<Station, MeasurementData> process() throws InterruptedException {
Map<Station, MeasurementData> result = HashMap.newHashMap(EXPECTED_MAPPINGS_NUM);
// an empty array is the stop signal
for (byte[] bytes = transfer.take(); bytes.length > 0; bytes = transfer.take()) {
// one mutable key view per chunk; slice() re-points it at each line's name for the lookup
Station station = new Station(bytes);
// read the bytes backwards
// (starts on the byte before the chunk's last byte, which is expected to be the final '\n';
// each pass ends with position on the '\n' before a line, and position-- then moves onto the
// last byte of the previous line)
for (int position = bytes.length - 2; position >= 1; position--) {
// calculate the temperature value
// operands are evaluated left to right: bytes[position] (fraction digit) is read first, then
// position jumps two bytes back, over the expected '.', to the lowest integer digit
int temperature = bytes[position] - '0' + (bytes[position -= 2] - '0') * 10;
// the byte before that digit is expected to be ';', '-' or a second integer digit
byte unknownByte = bytes[--position];
// every branch leaves position on, and yields, the index of the ';'
int semicolon = switch (unknownByte) {
case ';' -> position;
case '-' -> {
temperature = -temperature;
yield --position;
}
default -> {
// a second integer digit, optionally preceded by '-'
temperature += (unknownByte - '0') * 100;
if (bytes[--position] == '-') {
temperature = -temperature;
--position;
}
yield position;
}
};
// calculate the station name hash
// (walks backwards in 8-byte words; the first word ends on the ';' itself)
int hash = 1;
while (true) {
// previousLong reads the 8 bytes ending at position (inclusive) without a bounds check
// NOTE(review): for position < 7 the read starts before the first array element; presumably
// harmless because the '\n' at index 0 is found first — confirm
long temp = LineFinder.previousLong(bytes, position);
int distance = LineFinder.NATIVE.fromRight(temp);
if (distance == 0) {
// current byte is '\n'
break;
}
position -= distance;
if (distance == 8) {
// can't find '\n' in previous 8 bytes
// fold the full 64-bit word into the 32-bit hash
hash = 31 * hash + (int) (temp ^ (temp >>> 32));
continue;
}
// clear the redundant bytes
// (only the distance bytes after the '\n' belong to this name; position now sits on the '\n',
// so the next pass reports distance 0 and leaves the loop)
temp = LineFinder.NATIVE.clearLeft(temp, distance);
hash = 31 * hash + (int) (temp ^ (temp >>> 32));
}
// merge data to the result map
// the name is bytes[position + 1, semicolon); the shared view is the lookup key and is
// copied into its own array only when a new station is inserted
MeasurementData data = result.get(station.slice(hash, position + 1, semicolon));
if (data == null) {
result.put(station.copy(), new MeasurementData(temperature));
} else {
data.merge(temperature);
}
}
}
return result;
}
}
/**
* The measurement name.
* To find the nearest newline byte position in a long.
*/
private record MeasurementName(byte[] bytes, int length) {
/**
 * Locates the newline byte nearest to the end of an 8-byte word read from a byte array.
 * <p>
 * A word is read with {@link #previousLong(byte[], long)}; how array bytes map onto the bits of
 * the word depends on the native byte order, so one implementation exists per order and
 * {@link #NATIVE} selects the matching one.
 */
private interface LineFinder {

    // choose the implementation according to the native byte order
    LineFinder NATIVE = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? LELineFinder.INST : BELineFinder.INST;

    Unsafe UNSAFE = initUnsafe();

    int BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class);

    int LONG_BYTES = Long.SIZE / Byte.SIZE;

    /**
     * Obtains the {@code Unsafe} singleton through its private static field.
     *
     * @return the {@code Unsafe} instance
     * @throws RuntimeException if the field is missing or inaccessible
     */
    static Unsafe initUnsafe() {
        try {
            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            theUnsafe.setAccessible(true);
            return (Unsafe) theUnsafe.get(Unsafe.class);
        }
        catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads the 8 bytes {@code bytes[offset - 7 .. offset]} as one long in native byte order.
     * The read is not bounds-checked.
     *
     * @param bytes  the array to read from
     * @param offset the index of the last byte of the word
     * @return the 8 bytes ending at {@code offset}
     */
    static long previousLong(byte[] bytes, long offset) {
        return UNSAFE.getLong(bytes, BYTE_ARRAY_BASE_OFFSET + offset + 1 - LONG_BYTES);
    }

    /**
     * Mark the highest bit of newline byte (0x0A) to 1.
     * <p>
     * The result has 0x80 in every byte lane that holds a newline and 0x00 in every other lane.
     * The per-lane form is used on purpose: the shorter {@code (v - 0x01..) & ~v & 0x80..} test
     * lets a borrow from a newline lane also mark a 0x0B byte in the next higher lane, which is
     * wrong when the most significant mark is the one being looked for.
     *
     * @param longBytes the word to inspect
     * @return a word with bit 7 of each newline lane set and all other bits clear
     */
    static long markHighestBit(long longBytes) {
        // newline lanes become 0x00
        long temp = longBytes ^ 0x0A0A0A0A0A0A0A0AL;
        // adding 0x7F to the low 7 bits sets bit 7 of a lane iff those bits are non-zero, and can
        // never carry into the neighbouring lane
        long nonZero = ((temp & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL) | temp;
        return ~(nonZero | 0x7F7F7F7F7F7F7F7FL);
    }

    /**
     * Find the nearest newline byte position from right to left.
     *
     * @param longBytes a word obtained from {@link #previousLong(byte[], long)}
     * @return how many bytes lie between the end of the word and the nearest newline:
     *         0 if the last byte is the newline, 8 if the word holds no newline
     */
    int fromRight(long longBytes);

    /**
     * Clear the left bytes out of the range.
     *
     * @param longBytes a word obtained from {@link #previousLong(byte[], long)}
     * @param keepNum   how many bytes at the end of the word to keep (0 to 7)
     * @return the word with all other bytes zeroed
     */
    long clearLeft(long longBytes, int keepNum);

    /** Little-endian layout: the last array byte is the most significant byte of the word. */
    enum LELineFinder implements LineFinder {
        INST;

        // MASKS[i] keeps the i most significant bytes; MASKS[0] stays 0
        private static final long[] MASKS = new long[8];

        static {
            for (int i = 1; i <= 7; i++) {
                MASKS[i] = 0xFFFFFFFFFFFFFFFFL << ((8 - i) << 3);
            }
        }

        @Override
        public int fromRight(long longBytes) {
            // leading zero bits / 8 = unmarked bytes above the most significant newline
            return Long.numberOfLeadingZeros(markHighestBit(longBytes)) >>> 3;
        }

        @Override
        public long clearLeft(long longBytes, int keepNum) {
            return longBytes & MASKS[keepNum];
        }
    }

    /** Big-endian layout: the last array byte is the least significant byte of the word. */
    enum BELineFinder implements LineFinder {
        INST;

        // MASKS[i] keeps the i least significant bytes; MASKS[0] stays 0
        private static final long[] MASKS = new long[8];

        static {
            for (int i = 1; i <= 7; i++) {
                MASKS[i] = 0xFFFFFFFFFFFFFFFFL >>> ((8 - i) << 3);
            }
        }

        @Override
        public int fromRight(long longBytes) {
            // trailing zero bits / 8 = unmarked bytes below the least significant newline
            return Long.numberOfTrailingZeros(markHighestBit(longBytes)) >>> 3;
        }

        @Override
        public long clearLeft(long longBytes, int keepNum) {
            return longBytes & MASKS[keepNum];
        }
    }
}
/**
* The station name wrapper ( bytes[from, to) ).
*/
private static class Station {
private final byte[] bytes;
private int from;
private int to;
private int hash;
public Station(byte[] bytes) {
this(bytes, 0, 0, 0);
}
public Station(byte[] bytes, int hash, int from, int to) {
this.bytes = bytes;
this.slice(hash, from, to);
}
public Station slice(int hash, int from, int to) {
this.hash = hash;
this.from = from;
this.to = to;
return this;
}
public Station copy() {
int length = to - from;
byte[] newBytes = new byte[length];
System.arraycopy(bytes, from, newBytes, 0, length);
return new Station(newBytes, hash, 0, length);
}
@Override
public boolean equals(Object name) {
MeasurementName other = (MeasurementName) name;
if (other.length != length) {
return false;
}
return Arrays.compare(bytes, 0, length, other.bytes, 0, length) == 0;
public boolean equals(Object station) {
Station other = (Station) station;
return Arrays.equals(bytes, from, to, other.bytes, other.from, other.to);
}
@Override
public int hashCode() {
int result = 1;
for (int i = 0; i < length; i++) {
result = 31 * result + bytes[i];
}
return result;
}
@Override
public String toString() {
return new String(bytes, 0, length, StandardCharsets.UTF_8);
return hash;
}
}
/**
* The measurement data.
* The measurement data wrapper ( temperature * 10 ).
*/
private static class MeasurementData {
private int min;
private int max;
private int sum;
private long sum;
private int count;
public MeasurementData(int value) {
@ -154,11 +404,15 @@ public class CalculateAverage_C5H12O5 {
this.count = 1;
}
public MeasurementData merge(MeasurementData data) {
return merge(data.min, data.max, data.sum, data.count);
public MeasurementData merge(int value) {
return merge(value, value, value, 1);
}
public MeasurementData merge(int min, int max, int sum, int count) {
public MeasurementData merge(MeasurementData other) {
return merge(other.min, other.max, other.sum, other.count);
}
public MeasurementData merge(int min, int max, long sum, int count) {
this.min = Math.min(this.min, min);
this.max = Math.max(this.max, max);
this.sum += sum;
@ -168,67 +422,7 @@ public class CalculateAverage_C5H12O5 {
@Override
public String toString() {
return (min / 10.0) + "/" + (Math.round((double) sum / count) / 10.0) + "/" + (max / 10.0);
}
}
/**
* The task to calculate.
*/
private static class Task implements Callable<Map<MeasurementName, MeasurementData>> {
@Override
public Map<MeasurementName, MeasurementData> call() throws InterruptedException {
// poll from queue and calculate
Map<MeasurementName, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
for (byte[] bytes = BYTES_QUEUE.take(); true; bytes = BYTES_QUEUE.take()) {
if (bytes.length == 0) {
break;
}
int start = 0;
for (int end = 0; end < bytes.length; end++) {
if (bytes[end] == '\n') {
byte[] newBytes = new byte[end - start];
System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
int semicolon = newBytes.length - 4;
for (; semicolon >= 0; semicolon--) {
if (newBytes[semicolon] == ';') {
break;
}
}
MeasurementName station = new MeasurementName(newBytes, semicolon);
int value = toInt(newBytes, semicolon + 1);
MeasurementData data = result.get(station);
if (data != null) {
data.merge(value, value, value, 1);
}
else {
result.put(station, new MeasurementData(value));
}
start = end + 1;
}
}
}
return result;
}
/**
* Convert the byte array to int.
*/
private static int toInt(byte[] bytes, int start) {
boolean negative = false;
int result = 0;
for (int i = start; i < bytes.length; i++) {
byte b = bytes[i];
if (b == '-') {
negative = true;
continue;
}
if (b != '.') {
result = result * 10 + (b - '0');
}
}
return negative ? -result : result;
return STR."\{min / 10.0}/\{Math.round((double) sum / count) / 10.0}/\{max / 10.0}";
}
}
}