diff --git a/calculate_average_kuduwa-keshavram.sh b/calculate_average_kuduwa-keshavram.sh index 904c8db..33941d3 100755 --- a/calculate_average_kuduwa-keshavram.sh +++ b/calculate_average_kuduwa-keshavram.sh @@ -16,5 +16,5 @@ # -JAVA_OPTS="" +JAVA_OPTS="--enable-preview" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_kuduwa_keshavram diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_kuduwa_keshavram.java b/src/main/java/dev/morling/onebrc/CalculateAverage_kuduwa_keshavram.java index c611166..68ace02 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_kuduwa_keshavram.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_kuduwa_keshavram.java @@ -17,75 +17,77 @@ package dev.morling.onebrc; import java.io.File; import java.io.IOException; -import java.io.RandomAccessFile; -import java.nio.ByteOrder; -import java.nio.MappedByteBuffer; +import java.lang.foreign.Arena; +import java.lang.reflect.Field; import java.nio.channels.FileChannel; import java.nio.channels.FileChannel.MapMode; -import java.nio.file.Files; -import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.Arrays; -import java.util.List; -import java.util.Objects; +import java.util.Iterator; +import java.util.Spliterator; +import java.util.Spliterators; import java.util.TreeMap; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import sun.misc.Unsafe; public class CalculateAverage_kuduwa_keshavram { private static final String FILE = "./measurements.txt"; + private static final Unsafe UNSAFE = initUnsafe(); + + private static Unsafe initUnsafe() { + try { + final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + return (Unsafe) theUnsafe.get(Unsafe.class); + } + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } public static void main(String[] args) throws IOException, InterruptedException { - TreeMap resultMap = getFileSegments(new File(FILE)).stream() - .parallel() - .map( + TreeMap resultMap = getFileSegments(new File(FILE)) + .flatMap( segment -> { - final Measurement[][] measurements = new Measurement[1024 * 128][3]; - try (FileChannel fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ)) { - MappedByteBuffer byteBuffer = fileChannel.map( - MapMode.READ_ONLY, segment.start, segment.end - segment.start); - byteBuffer.order(ByteOrder.nativeOrder()); - while (byteBuffer.hasRemaining()) { - byte[] city = new byte[100]; - byte b; - int hash = 0; - int i = 0; - while ((b = byteBuffer.get()) != 59) { - hash = 31 * hash + b; - city[i++] = b; - } - - byte[] newCity = new byte[i]; - System.arraycopy(city, 0, newCity, 0, i); - int measurement = 0; - boolean negative = false; - while ((b = byteBuffer.get()) != 10) { - if (b == 45) { - negative = true; - } - else if (b == 46) { - // skip - } - else { - final int n = b - '0'; - measurement = measurement * 10 + n; - } - } - putOrMerge( - measurements, - new Measurement( - hash, newCity, negative ? measurement * -1 : measurement)); + Result result = new Result(); + while (segment.start < segment.end) { + byte[] city = new byte[100]; + byte b; + int hash = 0; + int i = 0; + while ((b = UNSAFE.getByte(segment.start++)) != 59) { + hash = 31 * hash + b; + city[i++] = b; } + + byte[] newCity = new byte[i]; + System.arraycopy(city, 0, newCity, 0, i); + int measurement = 0; + boolean negative = false; + while ((b = UNSAFE.getByte(segment.start++)) != 10) { + if (b == 45) { + negative = true; + } + else if (b == 46) { + // skip + } + else { + final int n = b - '0'; + measurement = measurement * 10 + n; + } + } + putOrMerge( + result, + new Measurement(hash, newCity, negative ? measurement * -1 : measurement)); } - catch (IOException e) { - throw new RuntimeException(e); - } - return measurements; + Iterator iterator = getMeasurementIterator(result); + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize(iterator, Spliterator.NONNULL), true); }) - .flatMap(measurements -> Arrays.stream(measurements).flatMap(Arrays::stream)) - .filter(Objects::nonNull) .collect( Collectors.toMap( measurement -> new String(measurement.city), @@ -99,13 +101,48 @@ public class CalculateAverage_kuduwa_keshavram { System.out.println(resultMap); } - private static void putOrMerge(Measurement[][] measurements, Measurement measurement) { - int index = measurement.hash & (measurements.length - 1); - Measurement[] existing = measurements[index]; + private static Iterator getMeasurementIterator(Result result) { + return new Iterator<>() { + final int uniqueIndex = result.uniqueIndex; + final int[] indexArray = result.indexArray; + final Measurement[][] measurements = result.measurements; + + int i = 0; + int j = 0; + + @Override + public boolean hasNext() { + return i < uniqueIndex; + } + + @Override + public Measurement next() { + Measurement measurement = measurements[indexArray[i]][j++]; + if (measurements[indexArray[i]][j] == null) { + i++; + j = 0; + } + return measurement; + } + }; + } + + static class Result { + final Measurement[][] measurements = new Measurement[1024 * 128][3]; + final int[] indexArray = new int[10_000]; + int uniqueIndex = 0; + } + + private static void putOrMerge(Result result, Measurement measurement) { + int index = measurement.hash & (result.measurements.length - 1); + Measurement[] existing = result.measurements[index]; for (int i = 0; i < existing.length; i++) { Measurement existingMeasurement = existing[i]; if (existingMeasurement == null) { - measurements[index][i] = measurement; + result.measurements[index][i] = measurement; + if (i == 0) { + result.indexArray[result.uniqueIndex++] = index; + } return; } if (equals(existingMeasurement.city, measurement.city)) { @@ -124,13 +161,20 @@ public class CalculateAverage_kuduwa_keshavram { return true; } - private record FileSegment(long start, long end) { + private static final class FileSegment { + long start; + long end; + + private FileSegment(long start, long end) { + this.start = start; + this.end = end; + } } private static final class Measurement { - private int hash; - private byte[] city; + private final int hash; + private final byte[] city; int min; int max; @@ -158,45 +202,28 @@ public class CalculateAverage_kuduwa_keshavram { } } - private static List getFileSegments(final File file) throws IOException { + private static Stream getFileSegments(final File file) throws IOException { final int numberOfSegments = Runtime.getRuntime().availableProcessors() * 4; - final long fileSize = file.length(); - final long segmentSize = fileSize / numberOfSegments; - if (segmentSize < 1000) { - return List.of(new FileSegment(0, fileSize)); - } - - try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) { - int lastSegment = numberOfSegments - 1; - return IntStream.range(0, numberOfSegments) - .mapToObj( - i -> { - long segStart = i * segmentSize; - long segEnd = (i == lastSegment) ? fileSize : segStart + segmentSize; - try { - segStart = findSegment(i, 0, randomAccessFile, segStart, segEnd); - segEnd = findSegment(i, lastSegment, randomAccessFile, segEnd, fileSize); - } - catch (IOException e) { - throw new RuntimeException(e); - } - return new FileSegment(segStart, segEnd); - }) - .toList(); - } - } - - private static long findSegment( - final int i, final int skipSegment, RandomAccessFile raf, long location, final long fileSize) - throws IOException { - if (i != skipSegment) { - raf.seek(location); - while (location < fileSize) { - location++; - if (raf.read() == '\n') - return location; + final long[] chunks = new long[numberOfSegments + 1]; + try (var fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ)) { + final long fileSize = fileChannel.size(); + final long segmentSize = (fileSize + numberOfSegments - 1) / numberOfSegments; + final long mappedAddress = fileChannel.map(MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); + chunks[0] = mappedAddress; + final long endAddress = mappedAddress + fileSize; + for (int i = 1; i < numberOfSegments; ++i) { + long chunkAddress = mappedAddress + i * segmentSize; + // Align to first row start. + while (chunkAddress < endAddress && UNSAFE.getByte(chunkAddress++) != '\n') { + // nop + } + chunks[i] = Math.min(chunkAddress, endAddress); } + chunks[numberOfSegments] = endAddress; } - return location; + return IntStream.range(0, chunks.length - 1) + .mapToObj(chunkIndex -> new FileSegment(chunks[chunkIndex], chunks[chunkIndex + 1])) + .parallel(); } + }