1brc/src/main/java/dev/morling/onebrc/CalculateAverage_raipc.java
Anton Rybochkin ad7a6ec573
Initial solution by raipc
* Initial solution by raipc

* Implemented custom hash map with open addressing

* Small optimizations to task splitting and range check disabling

* Fixed off-by-one error in merge

* Run with EpsilonGC. Borrowed VM params from Shipilev

* Make script executable

* Add a license
2024-01-10 22:26:00 +01:00

471 lines
17 KiB
Java

/*
* Copyright 2023 The original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.morling.onebrc;
import java.io.*;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Comparator;
import java.util.concurrent.RecursiveTask;
public class CalculateAverage_raipc {
private static final String FILE = "./measurements.txt";
private static final int BUFFER_SIZE = 16 * 1024;
private static final class AggregatedMeasurement {
final ByteArrayWrapper station;
private String stationStringCached;
private int min = Integer.MAX_VALUE;
private int max = Integer.MIN_VALUE;
private int count;
private long sum;
private AggregatedMeasurement(ByteArrayWrapper station) {
this.station = station;
}
public String station() {
String s = this.stationStringCached;
if (s == null) {
this.stationStringCached = s = station.toString();
}
return s;
}
public AggregatedMeasurement merge(AggregatedMeasurement other) {
this.min = Math.min(this.min, other.min);
this.max = Math.max(this.max, other.max);
this.sum += other.sum;
this.count += other.count;
return this;
}
public void add(int value) {
this.min = Math.min(this.min, value);
this.max = Math.max(this.max, value);
this.sum += value;
this.count++;
}
public void format(StringBuilder out) {
out.append(station()).append('=')
.append(min / 10.0).append('/')
.append(Math.round(1.0 * sum / count) / 10.0).append('/')
.append(max / 10.0);
}
}
public static void main(String[] args) throws InterruptedException {
File inputFile = new File(FILE);
ParsingTask parsingTask = new ParsingTask(inputFile, 0, inputFile.length());
AggregatedMeasurement[] results = parsingTask.fork().join().toArray();
System.out.println(formatResult(results));
}
private static String formatResult(AggregatedMeasurement[] result) {
Arrays.sort(result, Comparator.comparing(AggregatedMeasurement::station));
StringBuilder out = new StringBuilder().append('{');
for (AggregatedMeasurement item : result) {
item.format(out);
out.append(", ");
}
if (out.length() > 2) {
out.setLength(out.length() - 2);
}
return out.append('}').toString();
}
private static class ParsingTask extends RecursiveTask<MyHashMap> {
private static final int SPLIT_FACTOR = 4;
private static final int MIN_TASK_SIZE = 256 * 1024;
private final File file;
private final long startPosition;
private final long endPosition;
private ParsingTask(File file, long startPosition, long endPosition) {
this.file = file;
this.startPosition = startPosition;
this.endPosition = endPosition;
}
@Override
protected MyHashMap compute() {
long size = endPosition - startPosition;
if (size <= MIN_TASK_SIZE || size < file.length() / Runtime.getRuntime().availableProcessors() / SPLIT_FACTOR) {
return doCompute();
}
var firstHalf = new ParsingTask(file, startPosition, (startPosition + endPosition) / 2).fork();
var secondHalf = new ParsingTask(file, (startPosition + endPosition) / 2, endPosition).fork();
var firstHalfResults = firstHalf.join();
var secondHalfResults = secondHalf.join();
firstHalfResults.merge(secondHalfResults);
return firstHalfResults;
}
private MyHashMap doCompute() {
try (RandomAccessFile raf = new RandomAccessFile(file, "r")) {
return new ParsingRoutine(startPosition, endPosition).parse(raf);
}
catch (IOException e) {
throw new UncheckedIOException(e);
}
}
}
private static class ParsingRoutine {
private final MyHashMap result = new MyHashMap(2048);
private final byte[] partialContentBuff = new byte[128];
private final ByteArrayWrapper reusableWrapper = new ByteArrayWrapper(partialContentBuff, 0, 0, 0);
private int partialSize;
private final long startPosition;
private final long endPosition;
private ParsingRoutine(long startPosition, long endPosition) {
this.startPosition = startPosition;
this.endPosition = endPosition;
}
private boolean hasPartialContent() {
return partialSize > 0;
}
MyHashMap parse(RandomAccessFile raf) throws IOException {
byte[] buffer = new byte[BUFFER_SIZE];
long offset = findStartOffset(raf, buffer);
boolean readMore = offset <= endPosition;
while (readMore) {
raf.seek(offset);
int length = raf.read(buffer);
if (length == -1) {
break;
}
int bufStart = 0;
if (hasPartialContent()) {
int idxOfLf = indexOf(buffer, '\n', 0, length);
if (idxOfLf >= 0) {
bufStart = idxOfLf + 1;
System.arraycopy(buffer, 0, partialContentBuff, partialSize, bufStart);
int toProcess = partialSize + bufStart;
partialSize = 0;
processPart(offset - toProcess, partialContentBuff, 0, toProcess);
}
else {
System.arraycopy(buffer, 0, partialContentBuff, partialSize, bufStart);
offset += length;
continue;
}
}
readMore = processPart(offset, buffer, bufStart, length);
offset += length;
}
return result;
}
long findStartOffset(RandomAccessFile raf, byte[] buffer) throws IOException {
if (startPosition == 0) {
return 0;
}
long offset = startPosition - 1;
int length = 0;
int idxOfLf = -1;
while (offset < endPosition) {
raf.seek(offset);
length = raf.read(buffer);
if (length == -1) {
throw new IllegalStateException("No content read on position " + offset);
}
offset += length;
idxOfLf = indexOf(buffer, '\n', 0, length);
if (idxOfLf >= 0) {
break;
}
}
if (offset > startPosition) {
int start = idxOfLf + 1;
processPart(offset + start - length, buffer, start, length);
}
return offset;
}
boolean processPart(long position, byte[] buf, int start, int end) {
ByteArrayWrapper key = reusableWrapper;
key.content = buf;
while (position <= endPosition) {
int idxOfSemicolon = indexOf(buf, ';', start, end);
if (idxOfSemicolon >= 0) {
int idxOfLf = indexOf(buf, '\n', idxOfSemicolon + 2, Math.max(idxOfSemicolon + 2, end));
if (idxOfLf >= 0) {
key.start = start;
key.length = idxOfSemicolon - start;
var aggregatedMeasurement = result.getOrCreate(key);
aggregatedMeasurement.add(parseInt(buf, idxOfSemicolon + 1, idxOfLf - 1));
int prevStart = start;
start = idxOfLf + 1;
position += start - prevStart;
continue;
}
}
partialSize = end - start;
if (partialSize > 0) {
System.arraycopy(buf, start, partialContentBuff, 0, partialSize);
}
break;
}
return hasPartialContent() || position < endPosition;
}
private int parseInt(byte[] buf, int from, int toIncl) {
int mul = 1;
if (buf[from] == '-') {
mul = -1;
++from;
}
int res = buf[toIncl] - '0';
int dec = 10;
for (int i = toIncl - 2; i >= from; --i) {
res += (buf[i] - '0') * dec;
dec *= 10;
}
return mul * res;
}
}
private static final MethodHandle indexOfMH;
private static final MethodHandle vectorizedHashCodeMH;
private static final MethodHandle mismatchMH;
static {
try {
Class<?> stringLatin1 = Class.forName("java.lang.StringLatin1");
var lookup = MethodHandles.privateLookupIn(stringLatin1, MethodHandles.lookup());
// int indexOf(byte[] value, int ch, int fromIndex, int toIndex)
indexOfMH = lookup.findStatic(stringLatin1, "indexOf",
MethodType.methodType(int.class, byte[].class, int.class, int.class, int.class));
Class<?> arraysSupport = Class.forName("jdk.internal.util.ArraysSupport");
lookup = MethodHandles.privateLookupIn(arraysSupport, MethodHandles.lookup());
// int vectorizedHashCode(Object array, int fromIndex, int length, int initialValue, int basicType)
vectorizedHashCodeMH = lookup.findStatic(arraysSupport, "vectorizedHashCode",
MethodType.methodType(int.class, Object.class, int.class, int.class, int.class, int.class));
lookup = MethodHandles.privateLookupIn(arraysSupport, MethodHandles.lookup());
// int mismatch(byte[] a, int aFromIndex, byte[] b, int bFromIndex, int length)
mismatchMH = lookup.findStatic(arraysSupport, "mismatch",
MethodType.methodType(int.class, byte[].class, int.class, byte[].class, int.class, int.class));
}
catch (Exception e) {
throw new Error(e);
}
}
static int indexOf(byte[] src, int ch, int fromIndex, int toIndex) {
try {
return (int) indexOfMH.invoke(src, ch, fromIndex, toIndex);
}
catch (Throwable e) {
throw new Error(e);
}
}
static int vectorizedHashCode(byte[] array, int fromIndex, int length) {
try {
return (int) vectorizedHashCodeMH.invoke((Object) array, fromIndex, length, 0, 8);
}
catch (Throwable e) {
throw new Error(e);
}
}
static boolean arraysEqual(byte[] a, int aFromIndex, byte[] b, int bFromIndex, int length) {
try {
return ((int) mismatchMH.invoke(a, aFromIndex, b, bFromIndex, length)) < 0;
}
catch (Throwable e) {
throw new Error(e);
}
}
private static class ByteArrayWrapper {
private byte[] content;
private int start;
private int length;
int hashCodeCached;
ByteArrayWrapper(byte[] content, int start, int length, int hashCodeCached) {
this.content = content;
this.start = start;
this.length = length;
this.hashCodeCached = hashCodeCached;
}
int calculateHashCode() {
return this.hashCodeCached = vectorizedHashCode(content, start, length);
}
@Override
public boolean equals(Object obj) {
return obj instanceof ByteArrayWrapper bw && isEqualTo(bw);
}
boolean isEqualTo(ByteArrayWrapper other) {
return length == other.length && arraysEqual(content, start, other.content, other.start, length);
}
@Override
public int hashCode() {
return hashCodeCached;
}
@Override
public String toString() {
return new String(content, start, length, StandardCharsets.UTF_8);
}
ByteArrayWrapper copy() {
byte[] copy = Arrays.copyOfRange(content, start, start + length);
return new ByteArrayWrapper(copy, 0, copy.length, hashCodeCached);
}
}
private static class MyHashMap {
private static final int INVERSE_LOAD_FACTOR = 2;
private AggregatedMeasurement[] data;
private int size;
MyHashMap(int initialCapacity) {
this.data = new AggregatedMeasurement[initialCapacity];
}
private void grow() {
AggregatedMeasurement[] oldData = data;
AggregatedMeasurement[] newData = new AggregatedMeasurement[oldData.length * 2];
for (AggregatedMeasurement measurement : oldData) {
if (measurement != null) {
put(measurement, newData);
}
}
this.data = newData;
}
private void put(AggregatedMeasurement value, AggregatedMeasurement[] data) {
int length = data.length;
int hashCode = value.hashCode();
int idx = length & (hashCode ^ (hashCode >> 16));
for (int i = idx; i < length; ++i) {
if (data[i] == null) {
data[i] = value;
return;
}
}
for (int i = 0; i < idx; ++i) {
if (data[i] == null) {
data[i] = value;
return;
}
}
}
AggregatedMeasurement getOrCreate(ByteArrayWrapper key) {
AggregatedMeasurement[] data = this.data;
int length = data.length;
int hashCode = key.calculateHashCode();
int idx = (length - 1) & (hashCode ^ (hashCode >> 16));
AggregatedMeasurement result = doGetOrCreate(key, idx, length, data);
return result != null ? result : doGetOrCreate(key, 0, idx, data);
}
private AggregatedMeasurement doGetOrCreate(ByteArrayWrapper key, int from, int to, AggregatedMeasurement[] data) {
for (int i = from; i < to; ++i) {
AggregatedMeasurement item = data[i];
if (item != null) {
if (item.station.isEqualTo(key)) {
return item;
}
}
else {
AggregatedMeasurement result = new AggregatedMeasurement(key.copy());
++size;
if (size * INVERSE_LOAD_FACTOR >= data.length) {
grow();
put(result, this.data);
}
else {
data[i] = result;
}
return result;
}
}
return null;
}
void merge(MyHashMap other) {
for (AggregatedMeasurement measurement : other.data) {
if (measurement != null) {
merge(measurement);
}
}
}
private boolean merge(AggregatedMeasurement value) {
AggregatedMeasurement[] data = this.data;
int length = data.length;
ByteArrayWrapper key = value.station;
int hashCode = key.hashCode();
int idx = (length - 1) & (hashCode ^ (hashCode >> 16));
return doMerge(key, value, idx, length, data) || doMerge(key, value, 0, idx, data);
}
private boolean doMerge(ByteArrayWrapper key, AggregatedMeasurement value, int from, int to, AggregatedMeasurement[] data) {
for (int i = from; i < to; ++i) {
AggregatedMeasurement item = data[i];
if (item != null) {
if (item.station.isEqualTo(key)) {
item.merge(value);
return true;
}
}
else {
++size;
if (size * INVERSE_LOAD_FACTOR >= data.length) {
grow();
put(value, this.data);
}
else {
data[i] = value;
}
return true;
}
}
return false;
}
AggregatedMeasurement[] toArray() {
AggregatedMeasurement[] result = new AggregatedMeasurement[size];
int i = 0;
for (AggregatedMeasurement measurement : data) {
if (measurement != null) {
result[i++] = measurement;
}
}
return result;
}
}
}