ad7a6ec573
* Initial solution by raipc * Implemented custom hash map with open addressing * Small optimizations to task splitting and range check disabling * Fixed off-by-one error in merge * Run with EpsilonGC. Borrowed VM params from Shipilev * Make script executable * Add a license
471 lines
17 KiB
Java
471 lines
17 KiB
Java
/*
|
|
* Copyright 2023 The original authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
package dev.morling.onebrc;
|
|
|
|
import java.io.*;
|
|
import java.lang.invoke.MethodHandle;
|
|
import java.lang.invoke.MethodHandles;
|
|
import java.lang.invoke.MethodType;
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.util.Arrays;
|
|
import java.util.Comparator;
|
|
import java.util.concurrent.RecursiveTask;
|
|
|
|
public class CalculateAverage_raipc {
|
|
// Path of the measurements file opened by main, relative to the working directory.
private static final String FILE = "./measurements.txt";
|
|
// Size in bytes of the read buffer that ParsingRoutine.parse allocates for each range.
private static final int BUFFER_SIZE = 16 * 1024;
|
|
|
|
private static final class AggregatedMeasurement {
|
|
final ByteArrayWrapper station;
|
|
private String stationStringCached;
|
|
private int min = Integer.MAX_VALUE;
|
|
private int max = Integer.MIN_VALUE;
|
|
private int count;
|
|
private long sum;
|
|
|
|
private AggregatedMeasurement(ByteArrayWrapper station) {
|
|
this.station = station;
|
|
}
|
|
|
|
public String station() {
|
|
String s = this.stationStringCached;
|
|
if (s == null) {
|
|
this.stationStringCached = s = station.toString();
|
|
}
|
|
return s;
|
|
}
|
|
|
|
public AggregatedMeasurement merge(AggregatedMeasurement other) {
|
|
this.min = Math.min(this.min, other.min);
|
|
this.max = Math.max(this.max, other.max);
|
|
this.sum += other.sum;
|
|
this.count += other.count;
|
|
return this;
|
|
}
|
|
|
|
public void add(int value) {
|
|
this.min = Math.min(this.min, value);
|
|
this.max = Math.max(this.max, value);
|
|
this.sum += value;
|
|
this.count++;
|
|
}
|
|
|
|
public void format(StringBuilder out) {
|
|
out.append(station()).append('=')
|
|
.append(min / 10.0).append('/')
|
|
.append(Math.round(1.0 * sum / count) / 10.0).append('/')
|
|
.append(max / 10.0);
|
|
}
|
|
}
|
|
|
|
public static void main(String[] args) throws InterruptedException {
|
|
File inputFile = new File(FILE);
|
|
ParsingTask parsingTask = new ParsingTask(inputFile, 0, inputFile.length());
|
|
AggregatedMeasurement[] results = parsingTask.fork().join().toArray();
|
|
System.out.println(formatResult(results));
|
|
}
|
|
|
|
private static String formatResult(AggregatedMeasurement[] result) {
|
|
Arrays.sort(result, Comparator.comparing(AggregatedMeasurement::station));
|
|
StringBuilder out = new StringBuilder().append('{');
|
|
for (AggregatedMeasurement item : result) {
|
|
item.format(out);
|
|
out.append(", ");
|
|
}
|
|
if (out.length() > 2) {
|
|
out.setLength(out.length() - 2);
|
|
}
|
|
return out.append('}').toString();
|
|
}
|
|
|
|
private static class ParsingTask extends RecursiveTask<MyHashMap> {
|
|
private static final int SPLIT_FACTOR = 4;
|
|
private static final int MIN_TASK_SIZE = 256 * 1024;
|
|
private final File file;
|
|
private final long startPosition;
|
|
private final long endPosition;
|
|
|
|
private ParsingTask(File file, long startPosition, long endPosition) {
|
|
this.file = file;
|
|
this.startPosition = startPosition;
|
|
this.endPosition = endPosition;
|
|
}
|
|
|
|
@Override
|
|
protected MyHashMap compute() {
|
|
long size = endPosition - startPosition;
|
|
if (size <= MIN_TASK_SIZE || size < file.length() / Runtime.getRuntime().availableProcessors() / SPLIT_FACTOR) {
|
|
return doCompute();
|
|
}
|
|
var firstHalf = new ParsingTask(file, startPosition, (startPosition + endPosition) / 2).fork();
|
|
var secondHalf = new ParsingTask(file, (startPosition + endPosition) / 2, endPosition).fork();
|
|
var firstHalfResults = firstHalf.join();
|
|
var secondHalfResults = secondHalf.join();
|
|
firstHalfResults.merge(secondHalfResults);
|
|
return firstHalfResults;
|
|
}
|
|
|
|
private MyHashMap doCompute() {
|
|
try (RandomAccessFile raf = new RandomAccessFile(file, "r")) {
|
|
return new ParsingRoutine(startPosition, endPosition).parse(raf);
|
|
}
|
|
catch (IOException e) {
|
|
throw new UncheckedIOException(e);
|
|
}
|
|
}
|
|
}
|
|
|
|
private static class ParsingRoutine {
|
|
private final MyHashMap result = new MyHashMap(2048);
|
|
private final byte[] partialContentBuff = new byte[128];
|
|
private final ByteArrayWrapper reusableWrapper = new ByteArrayWrapper(partialContentBuff, 0, 0, 0);
|
|
private int partialSize;
|
|
private final long startPosition;
|
|
private final long endPosition;
|
|
|
|
private ParsingRoutine(long startPosition, long endPosition) {
|
|
this.startPosition = startPosition;
|
|
this.endPosition = endPosition;
|
|
}
|
|
|
|
private boolean hasPartialContent() {
|
|
return partialSize > 0;
|
|
}
|
|
|
|
MyHashMap parse(RandomAccessFile raf) throws IOException {
|
|
byte[] buffer = new byte[BUFFER_SIZE];
|
|
long offset = findStartOffset(raf, buffer);
|
|
boolean readMore = offset <= endPosition;
|
|
while (readMore) {
|
|
raf.seek(offset);
|
|
int length = raf.read(buffer);
|
|
if (length == -1) {
|
|
break;
|
|
}
|
|
int bufStart = 0;
|
|
if (hasPartialContent()) {
|
|
int idxOfLf = indexOf(buffer, '\n', 0, length);
|
|
if (idxOfLf >= 0) {
|
|
bufStart = idxOfLf + 1;
|
|
System.arraycopy(buffer, 0, partialContentBuff, partialSize, bufStart);
|
|
int toProcess = partialSize + bufStart;
|
|
partialSize = 0;
|
|
processPart(offset - toProcess, partialContentBuff, 0, toProcess);
|
|
}
|
|
else {
|
|
System.arraycopy(buffer, 0, partialContentBuff, partialSize, bufStart);
|
|
offset += length;
|
|
continue;
|
|
}
|
|
}
|
|
readMore = processPart(offset, buffer, bufStart, length);
|
|
offset += length;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
long findStartOffset(RandomAccessFile raf, byte[] buffer) throws IOException {
|
|
if (startPosition == 0) {
|
|
return 0;
|
|
}
|
|
long offset = startPosition - 1;
|
|
int length = 0;
|
|
int idxOfLf = -1;
|
|
while (offset < endPosition) {
|
|
raf.seek(offset);
|
|
length = raf.read(buffer);
|
|
if (length == -1) {
|
|
throw new IllegalStateException("No content read on position " + offset);
|
|
}
|
|
offset += length;
|
|
idxOfLf = indexOf(buffer, '\n', 0, length);
|
|
if (idxOfLf >= 0) {
|
|
break;
|
|
}
|
|
}
|
|
if (offset > startPosition) {
|
|
int start = idxOfLf + 1;
|
|
processPart(offset + start - length, buffer, start, length);
|
|
}
|
|
return offset;
|
|
}
|
|
|
|
boolean processPart(long position, byte[] buf, int start, int end) {
|
|
ByteArrayWrapper key = reusableWrapper;
|
|
key.content = buf;
|
|
while (position <= endPosition) {
|
|
int idxOfSemicolon = indexOf(buf, ';', start, end);
|
|
if (idxOfSemicolon >= 0) {
|
|
int idxOfLf = indexOf(buf, '\n', idxOfSemicolon + 2, Math.max(idxOfSemicolon + 2, end));
|
|
if (idxOfLf >= 0) {
|
|
key.start = start;
|
|
key.length = idxOfSemicolon - start;
|
|
var aggregatedMeasurement = result.getOrCreate(key);
|
|
aggregatedMeasurement.add(parseInt(buf, idxOfSemicolon + 1, idxOfLf - 1));
|
|
int prevStart = start;
|
|
start = idxOfLf + 1;
|
|
position += start - prevStart;
|
|
continue;
|
|
}
|
|
}
|
|
partialSize = end - start;
|
|
if (partialSize > 0) {
|
|
System.arraycopy(buf, start, partialContentBuff, 0, partialSize);
|
|
}
|
|
break;
|
|
}
|
|
return hasPartialContent() || position < endPosition;
|
|
}
|
|
|
|
private int parseInt(byte[] buf, int from, int toIncl) {
|
|
int mul = 1;
|
|
if (buf[from] == '-') {
|
|
mul = -1;
|
|
++from;
|
|
}
|
|
int res = buf[toIncl] - '0';
|
|
int dec = 10;
|
|
for (int i = toIncl - 2; i >= from; --i) {
|
|
res += (buf[i] - '0') * dec;
|
|
dec *= 10;
|
|
}
|
|
return mul * res;
|
|
}
|
|
}
|
|
|
|
// Handles to non-public JDK methods, resolved once when the class is initialised.
private static final MethodHandle indexOfMH;
private static final MethodHandle vectorizedHashCodeMH;
private static final MethodHandle mismatchMH;

static {
    try {
        MethodHandles.Lookup caller = MethodHandles.lookup();

        // int indexOf(byte[] value, int ch, int fromIndex, int toIndex)
        Class<?> latin1 = Class.forName("java.lang.StringLatin1");
        MethodType indexOfType = MethodType.methodType(int.class, byte[].class, int.class, int.class, int.class);
        indexOfMH = MethodHandles.privateLookupIn(latin1, caller).findStatic(latin1, "indexOf", indexOfType);

        Class<?> support = Class.forName("jdk.internal.util.ArraysSupport");
        MethodHandles.Lookup supportLookup = MethodHandles.privateLookupIn(support, caller);

        // int vectorizedHashCode(Object array, int fromIndex, int length, int initialValue, int basicType)
        MethodType hashType = MethodType.methodType(int.class, Object.class, int.class, int.class, int.class, int.class);
        vectorizedHashCodeMH = supportLookup.findStatic(support, "vectorizedHashCode", hashType);

        // int mismatch(byte[] a, int aFromIndex, byte[] b, int bFromIndex, int length)
        MethodType mismatchType = MethodType.methodType(int.class, byte[].class, int.class, byte[].class, int.class, int.class);
        mismatchMH = supportLookup.findStatic(support, "mismatch", mismatchType);
    }
    catch (Exception failure) {
        throw new Error(failure);
    }
}
|
|
|
|
static int indexOf(byte[] src, int ch, int fromIndex, int toIndex) {
|
|
try {
|
|
return (int) indexOfMH.invoke(src, ch, fromIndex, toIndex);
|
|
}
|
|
catch (Throwable e) {
|
|
throw new Error(e);
|
|
}
|
|
}
|
|
|
|
static int vectorizedHashCode(byte[] array, int fromIndex, int length) {
|
|
try {
|
|
return (int) vectorizedHashCodeMH.invoke((Object) array, fromIndex, length, 0, 8);
|
|
}
|
|
catch (Throwable e) {
|
|
throw new Error(e);
|
|
}
|
|
}
|
|
|
|
static boolean arraysEqual(byte[] a, int aFromIndex, byte[] b, int bFromIndex, int length) {
|
|
try {
|
|
return ((int) mismatchMH.invoke(a, aFromIndex, b, bFromIndex, length)) < 0;
|
|
}
|
|
catch (Throwable e) {
|
|
throw new Error(e);
|
|
}
|
|
}
|
|
|
|
private static class ByteArrayWrapper {
|
|
private byte[] content;
|
|
private int start;
|
|
private int length;
|
|
int hashCodeCached;
|
|
|
|
ByteArrayWrapper(byte[] content, int start, int length, int hashCodeCached) {
|
|
this.content = content;
|
|
this.start = start;
|
|
this.length = length;
|
|
this.hashCodeCached = hashCodeCached;
|
|
}
|
|
|
|
int calculateHashCode() {
|
|
return this.hashCodeCached = vectorizedHashCode(content, start, length);
|
|
}
|
|
|
|
@Override
|
|
public boolean equals(Object obj) {
|
|
return obj instanceof ByteArrayWrapper bw && isEqualTo(bw);
|
|
}
|
|
|
|
boolean isEqualTo(ByteArrayWrapper other) {
|
|
return length == other.length && arraysEqual(content, start, other.content, other.start, length);
|
|
}
|
|
|
|
@Override
|
|
public int hashCode() {
|
|
return hashCodeCached;
|
|
}
|
|
|
|
@Override
|
|
public String toString() {
|
|
return new String(content, start, length, StandardCharsets.UTF_8);
|
|
}
|
|
|
|
ByteArrayWrapper copy() {
|
|
byte[] copy = Arrays.copyOfRange(content, start, start + length);
|
|
return new ByteArrayWrapper(copy, 0, copy.length, hashCodeCached);
|
|
}
|
|
|
|
}
|
|
|
|
private static class MyHashMap {
|
|
private static final int INVERSE_LOAD_FACTOR = 2;
|
|
private AggregatedMeasurement[] data;
|
|
private int size;
|
|
|
|
MyHashMap(int initialCapacity) {
|
|
this.data = new AggregatedMeasurement[initialCapacity];
|
|
}
|
|
|
|
private void grow() {
|
|
AggregatedMeasurement[] oldData = data;
|
|
AggregatedMeasurement[] newData = new AggregatedMeasurement[oldData.length * 2];
|
|
for (AggregatedMeasurement measurement : oldData) {
|
|
if (measurement != null) {
|
|
put(measurement, newData);
|
|
}
|
|
}
|
|
this.data = newData;
|
|
}
|
|
|
|
private void put(AggregatedMeasurement value, AggregatedMeasurement[] data) {
|
|
int length = data.length;
|
|
int hashCode = value.hashCode();
|
|
int idx = length & (hashCode ^ (hashCode >> 16));
|
|
for (int i = idx; i < length; ++i) {
|
|
if (data[i] == null) {
|
|
data[i] = value;
|
|
return;
|
|
}
|
|
}
|
|
for (int i = 0; i < idx; ++i) {
|
|
if (data[i] == null) {
|
|
data[i] = value;
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
AggregatedMeasurement getOrCreate(ByteArrayWrapper key) {
|
|
AggregatedMeasurement[] data = this.data;
|
|
int length = data.length;
|
|
int hashCode = key.calculateHashCode();
|
|
int idx = (length - 1) & (hashCode ^ (hashCode >> 16));
|
|
AggregatedMeasurement result = doGetOrCreate(key, idx, length, data);
|
|
return result != null ? result : doGetOrCreate(key, 0, idx, data);
|
|
}
|
|
|
|
private AggregatedMeasurement doGetOrCreate(ByteArrayWrapper key, int from, int to, AggregatedMeasurement[] data) {
|
|
for (int i = from; i < to; ++i) {
|
|
AggregatedMeasurement item = data[i];
|
|
if (item != null) {
|
|
if (item.station.isEqualTo(key)) {
|
|
return item;
|
|
}
|
|
}
|
|
else {
|
|
AggregatedMeasurement result = new AggregatedMeasurement(key.copy());
|
|
++size;
|
|
if (size * INVERSE_LOAD_FACTOR >= data.length) {
|
|
grow();
|
|
put(result, this.data);
|
|
}
|
|
else {
|
|
data[i] = result;
|
|
}
|
|
return result;
|
|
}
|
|
}
|
|
return null;
|
|
}
|
|
|
|
void merge(MyHashMap other) {
|
|
for (AggregatedMeasurement measurement : other.data) {
|
|
if (measurement != null) {
|
|
merge(measurement);
|
|
}
|
|
}
|
|
}
|
|
|
|
private boolean merge(AggregatedMeasurement value) {
|
|
AggregatedMeasurement[] data = this.data;
|
|
int length = data.length;
|
|
ByteArrayWrapper key = value.station;
|
|
int hashCode = key.hashCode();
|
|
int idx = (length - 1) & (hashCode ^ (hashCode >> 16));
|
|
return doMerge(key, value, idx, length, data) || doMerge(key, value, 0, idx, data);
|
|
}
|
|
|
|
private boolean doMerge(ByteArrayWrapper key, AggregatedMeasurement value, int from, int to, AggregatedMeasurement[] data) {
|
|
for (int i = from; i < to; ++i) {
|
|
AggregatedMeasurement item = data[i];
|
|
if (item != null) {
|
|
if (item.station.isEqualTo(key)) {
|
|
item.merge(value);
|
|
return true;
|
|
}
|
|
}
|
|
else {
|
|
++size;
|
|
if (size * INVERSE_LOAD_FACTOR >= data.length) {
|
|
grow();
|
|
put(value, this.data);
|
|
}
|
|
else {
|
|
data[i] = value;
|
|
}
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
AggregatedMeasurement[] toArray() {
|
|
AggregatedMeasurement[] result = new AggregatedMeasurement[size];
|
|
int i = 0;
|
|
for (AggregatedMeasurement measurement : data) {
|
|
if (measurement != null) {
|
|
result[i++] = measurement;
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
}
|
|
|
|
}
|