Optimised code to iterate over non-null measurements (#444)

Co-authored-by: Keshavram Kuduwa <keshavram.kuduwa@apptware.com>
This commit is contained in:
Keshavram Kuduwa 2024-01-17 02:32:26 +05:30 committed by GitHub
parent c080143ca8
commit b1e6a120a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 124 additions and 97 deletions

View File

@ -16,5 +16,5 @@
# #
JAVA_OPTS="" JAVA_OPTS="--enable-preview"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_kuduwa_keshavram java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_kuduwa_keshavram

View File

@ -17,75 +17,77 @@ package dev.morling.onebrc;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.RandomAccessFile; import java.lang.foreign.Arena;
import java.nio.ByteOrder; import java.lang.reflect.Field;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel; import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode; import java.nio.channels.FileChannel.MapMode;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption; import java.nio.file.StandardOpenOption;
import java.util.Arrays; import java.util.Iterator;
import java.util.List; import java.util.Spliterator;
import java.util.Objects; import java.util.Spliterators;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import sun.misc.Unsafe;
public class CalculateAverage_kuduwa_keshavram { public class CalculateAverage_kuduwa_keshavram {
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
private static final Unsafe UNSAFE = initUnsafe();
/**
 * Obtains the {@link Unsafe} singleton through reflection.
 *
 * <p>The instance is read from the private static {@code theUnsafe} field of
 * {@code sun.misc.Unsafe}, which is made accessible first.
 *
 * @return the {@code Unsafe} instance stored in {@code Unsafe.theUnsafe}
 * @throws RuntimeException wrapping the reflective failure if the field does not exist or
 *         cannot be read
 */
private static Unsafe initUnsafe() {
    // Step 1: locate the field; a missing field is reported as an unchecked failure.
    final Field unsafeField;
    try {
        unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
    }
    catch (NoSuchFieldException missing) {
        throw new RuntimeException(missing);
    }

    // Step 2: lift the access check and read the static value.
    unsafeField.setAccessible(true);
    try {
        return (Unsafe) unsafeField.get(Unsafe.class);
    }
    catch (IllegalAccessException denied) {
        throw new RuntimeException(denied);
    }
}
public static void main(String[] args) throws IOException, InterruptedException { public static void main(String[] args) throws IOException, InterruptedException {
TreeMap<String, Measurement> resultMap = getFileSegments(new File(FILE)).stream() TreeMap<String, Measurement> resultMap = getFileSegments(new File(FILE))
.parallel() .flatMap(
.map(
segment -> { segment -> {
final Measurement[][] measurements = new Measurement[1024 * 128][3]; Result result = new Result();
try (FileChannel fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ)) { while (segment.start < segment.end) {
MappedByteBuffer byteBuffer = fileChannel.map( byte[] city = new byte[100];
MapMode.READ_ONLY, segment.start, segment.end - segment.start); byte b;
byteBuffer.order(ByteOrder.nativeOrder()); int hash = 0;
while (byteBuffer.hasRemaining()) { int i = 0;
byte[] city = new byte[100]; while ((b = UNSAFE.getByte(segment.start++)) != 59) {
byte b; hash = 31 * hash + b;
int hash = 0; city[i++] = b;
int i = 0;
while ((b = byteBuffer.get()) != 59) {
hash = 31 * hash + b;
city[i++] = b;
}
byte[] newCity = new byte[i];
System.arraycopy(city, 0, newCity, 0, i);
int measurement = 0;
boolean negative = false;
while ((b = byteBuffer.get()) != 10) {
if (b == 45) {
negative = true;
}
else if (b == 46) {
// skip
}
else {
final int n = b - '0';
measurement = measurement * 10 + n;
}
}
putOrMerge(
measurements,
new Measurement(
hash, newCity, negative ? measurement * -1 : measurement));
} }
byte[] newCity = new byte[i];
System.arraycopy(city, 0, newCity, 0, i);
int measurement = 0;
boolean negative = false;
while ((b = UNSAFE.getByte(segment.start++)) != 10) {
if (b == 45) {
negative = true;
}
else if (b == 46) {
// skip
}
else {
final int n = b - '0';
measurement = measurement * 10 + n;
}
}
putOrMerge(
result,
new Measurement(hash, newCity, negative ? measurement * -1 : measurement));
} }
catch (IOException e) { Iterator<Measurement> iterator = getMeasurementIterator(result);
throw new RuntimeException(e); return StreamSupport.stream(
} Spliterators.spliteratorUnknownSize(iterator, Spliterator.NONNULL), true);
return measurements;
}) })
.flatMap(measurements -> Arrays.stream(measurements).flatMap(Arrays::stream))
.filter(Objects::nonNull)
.collect( .collect(
Collectors.toMap( Collectors.toMap(
measurement -> new String(measurement.city), measurement -> new String(measurement.city),
@ -99,13 +101,48 @@ public class CalculateAverage_kuduwa_keshavram {
System.out.println(resultMap); System.out.println(resultMap);
} }
private static void putOrMerge(Measurement[][] measurements, Measurement measurement) { private static Iterator<Measurement> getMeasurementIterator(Result result) {
int index = measurement.hash & (measurements.length - 1); return new Iterator<>() {
Measurement[] existing = measurements[index]; final int uniqueIndex = result.uniqueIndex;
final int[] indexArray = result.indexArray;
final Measurement[][] measurements = result.measurements;
int i = 0;
int j = 0;
/**
 * Reports whether any recorded bucket is still left to visit.
 *
 * @return {@code true} while the bucket cursor {@code i} is below {@code uniqueIndex}
 */
@Override
public boolean hasNext() {
    return uniqueIndex > i;
}
/**
 * Returns the measurement at the current position and advances the cursor.
 *
 * <p>{@code i} selects a bucket through {@code indexArray}, and {@code j} is the slot inside
 * that bucket. Once the current bucket has no further entries, the cursor moves on to slot 0
 * of the next recorded bucket.
 *
 * @return the measurement stored at the current bucket/slot position
 */
@Override
public Measurement next() {
    final Measurement[] bucket = measurements[indexArray[i]];
    final Measurement measurement = bucket[j++];
    // A bucket is exhausted either when every slot has been consumed or when the next slot
    // is unused. The length check must come first: probing bucket[j] on a completely full
    // bucket would read one past its end and throw ArrayIndexOutOfBoundsException.
    if (j == bucket.length || bucket[j] == null) {
        i++;
        j = 0;
    }
    return measurement;
}
};
}
/**
 * Mutable accumulator that {@code putOrMerge} fills and the measurement iterator reads.
 */
static class Result {
    /** Bucket table; {@code putOrMerge} picks a bucket with {@code hash & (length - 1)}. */
    final Measurement[][] measurements;
    /** Indices of the buckets that have received a first entry, in order of first use. */
    final int[] indexArray;
    /** Number of valid entries in {@link #indexArray}. */
    int uniqueIndex;

    Result() {
        this.measurements = new Measurement[1024 * 128][3];
        this.indexArray = new int[10_000];
        this.uniqueIndex = 0;
    }
}
private static void putOrMerge(Result result, Measurement measurement) {
int index = measurement.hash & (result.measurements.length - 1);
Measurement[] existing = result.measurements[index];
for (int i = 0; i < existing.length; i++) { for (int i = 0; i < existing.length; i++) {
Measurement existingMeasurement = existing[i]; Measurement existingMeasurement = existing[i];
if (existingMeasurement == null) { if (existingMeasurement == null) {
measurements[index][i] = measurement; result.measurements[index][i] = measurement;
if (i == 0) {
result.indexArray[result.uniqueIndex++] = index;
}
return; return;
} }
if (equals(existingMeasurement.city, measurement.city)) { if (equals(existingMeasurement.city, measurement.city)) {
@ -124,13 +161,20 @@ public class CalculateAverage_kuduwa_keshavram {
return true; return true;
} }
private record FileSegment(long start, long end) { private static final class FileSegment {
long start;
long end;
private FileSegment(long start, long end) {
this.start = start;
this.end = end;
}
} }
private static final class Measurement { private static final class Measurement {
private int hash; private final int hash;
private byte[] city; private final byte[] city;
int min; int min;
int max; int max;
@ -158,45 +202,28 @@ public class CalculateAverage_kuduwa_keshavram {
} }
} }
private static List<FileSegment> getFileSegments(final File file) throws IOException { private static Stream<FileSegment> getFileSegments(final File file) throws IOException {
final int numberOfSegments = Runtime.getRuntime().availableProcessors() * 4; final int numberOfSegments = Runtime.getRuntime().availableProcessors() * 4;
final long fileSize = file.length(); final long[] chunks = new long[numberOfSegments + 1];
final long segmentSize = fileSize / numberOfSegments; try (var fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ)) {
if (segmentSize < 1000) { final long fileSize = fileChannel.size();
return List.of(new FileSegment(0, fileSize)); final long segmentSize = (fileSize + numberOfSegments - 1) / numberOfSegments;
} final long mappedAddress = fileChannel.map(MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
chunks[0] = mappedAddress;
try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) { final long endAddress = mappedAddress + fileSize;
int lastSegment = numberOfSegments - 1; for (int i = 1; i < numberOfSegments; ++i) {
return IntStream.range(0, numberOfSegments) long chunkAddress = mappedAddress + i * segmentSize;
.mapToObj( // Align to first row start.
i -> { while (chunkAddress < endAddress && UNSAFE.getByte(chunkAddress++) != '\n') {
long segStart = i * segmentSize; // nop
long segEnd = (i == lastSegment) ? fileSize : segStart + segmentSize; }
try { chunks[i] = Math.min(chunkAddress, endAddress);
segStart = findSegment(i, 0, randomAccessFile, segStart, segEnd);
segEnd = findSegment(i, lastSegment, randomAccessFile, segEnd, fileSize);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return new FileSegment(segStart, segEnd);
})
.toList();
}
}
private static long findSegment(
final int i, final int skipSegment, RandomAccessFile raf, long location, final long fileSize)
throws IOException {
if (i != skipSegment) {
raf.seek(location);
while (location < fileSize) {
location++;
if (raf.read() == '\n')
return location;
} }
chunks[numberOfSegments] = endAddress;
} }
return location; return IntStream.range(0, chunks.length - 1)
.mapToObj(chunkIndex -> new FileSegment(chunks[chunkIndex], chunks[chunkIndex + 1]))
.parallel();
} }
} }