Optimised code to iterate over non-null measurements (#444)

Co-authored-by: Keshavram Kuduwa <keshavram.kuduwa@apptware.com>
This commit is contained in:
Keshavram Kuduwa 2024-01-17 02:32:26 +05:30 committed by GitHub
parent c080143ca8
commit b1e6a120a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 124 additions and 97 deletions

View File

@ -16,5 +16,5 @@
#
JAVA_OPTS=""
JAVA_OPTS="--enable-preview"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_kuduwa_keshavram

View File

@ -17,42 +17,49 @@ package dev.morling.onebrc;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Iterator;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import sun.misc.Unsafe;
public class CalculateAverage_kuduwa_keshavram {
private static final String FILE = "./measurements.txt";
private static final Unsafe UNSAFE = initUnsafe();
private static Unsafe initUnsafe() {
try {
final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
theUnsafe.setAccessible(true);
return (Unsafe) theUnsafe.get(Unsafe.class);
}
catch (NoSuchFieldException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}
public static void main(String[] args) throws IOException, InterruptedException {
TreeMap<String, Measurement> resultMap = getFileSegments(new File(FILE)).stream()
.parallel()
.map(
TreeMap<String, Measurement> resultMap = getFileSegments(new File(FILE))
.flatMap(
segment -> {
final Measurement[][] measurements = new Measurement[1024 * 128][3];
try (FileChannel fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ)) {
MappedByteBuffer byteBuffer = fileChannel.map(
MapMode.READ_ONLY, segment.start, segment.end - segment.start);
byteBuffer.order(ByteOrder.nativeOrder());
while (byteBuffer.hasRemaining()) {
Result result = new Result();
while (segment.start < segment.end) {
byte[] city = new byte[100];
byte b;
int hash = 0;
int i = 0;
while ((b = byteBuffer.get()) != 59) {
while ((b = UNSAFE.getByte(segment.start++)) != 59) {
hash = 31 * hash + b;
city[i++] = b;
}
@ -61,7 +68,7 @@ public class CalculateAverage_kuduwa_keshavram {
System.arraycopy(city, 0, newCity, 0, i);
int measurement = 0;
boolean negative = false;
while ((b = byteBuffer.get()) != 10) {
while ((b = UNSAFE.getByte(segment.start++)) != 10) {
if (b == 45) {
negative = true;
}
@ -74,18 +81,13 @@ public class CalculateAverage_kuduwa_keshavram {
}
}
putOrMerge(
measurements,
new Measurement(
hash, newCity, negative ? measurement * -1 : measurement));
result,
new Measurement(hash, newCity, negative ? measurement * -1 : measurement));
}
}
catch (IOException e) {
throw new RuntimeException(e);
}
return measurements;
Iterator<Measurement> iterator = getMeasurementIterator(result);
return StreamSupport.stream(
Spliterators.spliteratorUnknownSize(iterator, Spliterator.NONNULL), true);
})
.flatMap(measurements -> Arrays.stream(measurements).flatMap(Arrays::stream))
.filter(Objects::nonNull)
.collect(
Collectors.toMap(
measurement -> new String(measurement.city),
@ -99,13 +101,48 @@ public class CalculateAverage_kuduwa_keshavram {
System.out.println(resultMap);
}
private static void putOrMerge(Measurement[][] measurements, Measurement measurement) {
int index = measurement.hash & (measurements.length - 1);
Measurement[] existing = measurements[index];
private static Iterator<Measurement> getMeasurementIterator(Result result) {
return new Iterator<>() {
final int uniqueIndex = result.uniqueIndex;
final int[] indexArray = result.indexArray;
final Measurement[][] measurements = result.measurements;
int i = 0;
int j = 0;
@Override
public boolean hasNext() {
return i < uniqueIndex;
}
@Override
public Measurement next() {
Measurement measurement = measurements[indexArray[i]][j++];
if (measurements[indexArray[i]][j] == null) {
i++;
j = 0;
}
return measurement;
}
};
}
static class Result {
final Measurement[][] measurements = new Measurement[1024 * 128][3];
final int[] indexArray = new int[10_000];
int uniqueIndex = 0;
}
private static void putOrMerge(Result result, Measurement measurement) {
int index = measurement.hash & (result.measurements.length - 1);
Measurement[] existing = result.measurements[index];
for (int i = 0; i < existing.length; i++) {
Measurement existingMeasurement = existing[i];
if (existingMeasurement == null) {
measurements[index][i] = measurement;
result.measurements[index][i] = measurement;
if (i == 0) {
result.indexArray[result.uniqueIndex++] = index;
}
return;
}
if (equals(existingMeasurement.city, measurement.city)) {
@ -124,13 +161,20 @@ public class CalculateAverage_kuduwa_keshavram {
return true;
}
private record FileSegment(long start, long end) {
private static final class FileSegment {
long start;
long end;
private FileSegment(long start, long end) {
this.start = start;
this.end = end;
}
}
private static final class Measurement {
private int hash;
private byte[] city;
private final int hash;
private final byte[] city;
int min;
int max;
@ -158,45 +202,28 @@ public class CalculateAverage_kuduwa_keshavram {
}
}
private static List<FileSegment> getFileSegments(final File file) throws IOException {
private static Stream<FileSegment> getFileSegments(final File file) throws IOException {
final int numberOfSegments = Runtime.getRuntime().availableProcessors() * 4;
final long fileSize = file.length();
final long segmentSize = fileSize / numberOfSegments;
if (segmentSize < 1000) {
return List.of(new FileSegment(0, fileSize));
final long[] chunks = new long[numberOfSegments + 1];
try (var fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ)) {
final long fileSize = fileChannel.size();
final long segmentSize = (fileSize + numberOfSegments - 1) / numberOfSegments;
final long mappedAddress = fileChannel.map(MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
chunks[0] = mappedAddress;
final long endAddress = mappedAddress + fileSize;
for (int i = 1; i < numberOfSegments; ++i) {
long chunkAddress = mappedAddress + i * segmentSize;
// Align to first row start.
while (chunkAddress < endAddress && UNSAFE.getByte(chunkAddress++) != '\n') {
// nop
}
chunks[i] = Math.min(chunkAddress, endAddress);
}
chunks[numberOfSegments] = endAddress;
}
return IntStream.range(0, chunks.length - 1)
.mapToObj(chunkIndex -> new FileSegment(chunks[chunkIndex], chunks[chunkIndex + 1]))
.parallel();
}
try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) {
int lastSegment = numberOfSegments - 1;
return IntStream.range(0, numberOfSegments)
.mapToObj(
i -> {
long segStart = i * segmentSize;
long segEnd = (i == lastSegment) ? fileSize : segStart + segmentSize;
try {
segStart = findSegment(i, 0, randomAccessFile, segStart, segEnd);
segEnd = findSegment(i, lastSegment, randomAccessFile, segEnd, fileSize);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return new FileSegment(segStart, segEnd);
})
.toList();
}
}
private static long findSegment(
final int i, final int skipSegment, RandomAccessFile raf, long location, final long fileSize)
throws IOException {
if (i != skipSegment) {
raf.seek(location);
while (location < fileSize) {
location++;
if (raf.read() == '\n')
return location;
}
}
return location;
}
}