Optimised code to iterate over non-null measurements (#444)
Co-authored-by: Keshavram Kuduwa <keshavram.kuduwa@apptware.com>
This commit is contained in:
		| @@ -17,75 +17,77 @@ package dev.morling.onebrc; | ||||
|  | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.io.RandomAccessFile; | ||||
| import java.nio.ByteOrder; | ||||
| import java.nio.MappedByteBuffer; | ||||
| import java.lang.foreign.Arena; | ||||
| import java.lang.reflect.Field; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.channels.FileChannel.MapMode; | ||||
| import java.nio.file.Files; | ||||
| import java.nio.file.Path; | ||||
| import java.nio.file.StandardOpenOption; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| import java.util.Objects; | ||||
| import java.util.Iterator; | ||||
| import java.util.Spliterator; | ||||
| import java.util.Spliterators; | ||||
| import java.util.TreeMap; | ||||
| import java.util.function.Function; | ||||
| import java.util.stream.Collectors; | ||||
| import java.util.stream.IntStream; | ||||
| import java.util.stream.Stream; | ||||
| import java.util.stream.StreamSupport; | ||||
| import sun.misc.Unsafe; | ||||
|  | ||||
| public class CalculateAverage_kuduwa_keshavram { | ||||
|  | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|     private static final Unsafe UNSAFE = initUnsafe(); | ||||
|  | ||||
|     private static Unsafe initUnsafe() { | ||||
|         try { | ||||
|             final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); | ||||
|             theUnsafe.setAccessible(true); | ||||
|             return (Unsafe) theUnsafe.get(Unsafe.class); | ||||
|         } | ||||
|         catch (NoSuchFieldException | IllegalAccessException e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     public static void main(String[] args) throws IOException, InterruptedException { | ||||
|         TreeMap<String, Measurement> resultMap = getFileSegments(new File(FILE)).stream() | ||||
|                 .parallel() | ||||
|                 .map( | ||||
|         TreeMap<String, Measurement> resultMap = getFileSegments(new File(FILE)) | ||||
|                 .flatMap( | ||||
|                         segment -> { | ||||
|                             final Measurement[][] measurements = new Measurement[1024 * 128][3]; | ||||
|                             try (FileChannel fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ)) { | ||||
|                                 MappedByteBuffer byteBuffer = fileChannel.map( | ||||
|                                         MapMode.READ_ONLY, segment.start, segment.end - segment.start); | ||||
|                                 byteBuffer.order(ByteOrder.nativeOrder()); | ||||
|                                 while (byteBuffer.hasRemaining()) { | ||||
|                                     byte[] city = new byte[100]; | ||||
|                                     byte b; | ||||
|                                     int hash = 0; | ||||
|                                     int i = 0; | ||||
|                                     while ((b = byteBuffer.get()) != 59) { | ||||
|                                         hash = 31 * hash + b; | ||||
|                                         city[i++] = b; | ||||
|                                     } | ||||
|  | ||||
|                                     byte[] newCity = new byte[i]; | ||||
|                                     System.arraycopy(city, 0, newCity, 0, i); | ||||
|                                     int measurement = 0; | ||||
|                                     boolean negative = false; | ||||
|                                     while ((b = byteBuffer.get()) != 10) { | ||||
|                                         if (b == 45) { | ||||
|                                             negative = true; | ||||
|                                         } | ||||
|                                         else if (b == 46) { | ||||
|                                             // skip | ||||
|                                         } | ||||
|                                         else { | ||||
|                                             final int n = b - '0'; | ||||
|                                             measurement = measurement * 10 + n; | ||||
|                                         } | ||||
|                                     } | ||||
|                                     putOrMerge( | ||||
|                                             measurements, | ||||
|                                             new Measurement( | ||||
|                                                     hash, newCity, negative ? measurement * -1 : measurement)); | ||||
|                             Result result = new Result(); | ||||
|                             while (segment.start < segment.end) { | ||||
|                                 byte[] city = new byte[100]; | ||||
|                                 byte b; | ||||
|                                 int hash = 0; | ||||
|                                 int i = 0; | ||||
|                                 while ((b = UNSAFE.getByte(segment.start++)) != 59) { | ||||
|                                     hash = 31 * hash + b; | ||||
|                                     city[i++] = b; | ||||
|                                 } | ||||
|  | ||||
|                                 byte[] newCity = new byte[i]; | ||||
|                                 System.arraycopy(city, 0, newCity, 0, i); | ||||
|                                 int measurement = 0; | ||||
|                                 boolean negative = false; | ||||
|                                 while ((b = UNSAFE.getByte(segment.start++)) != 10) { | ||||
|                                     if (b == 45) { | ||||
|                                         negative = true; | ||||
|                                     } | ||||
|                                     else if (b == 46) { | ||||
|                                         // skip | ||||
|                                     } | ||||
|                                     else { | ||||
|                                         final int n = b - '0'; | ||||
|                                         measurement = measurement * 10 + n; | ||||
|                                     } | ||||
|                                 } | ||||
|                                 putOrMerge( | ||||
|                                         result, | ||||
|                                         new Measurement(hash, newCity, negative ? measurement * -1 : measurement)); | ||||
|                             } | ||||
|                             catch (IOException e) { | ||||
|                                 throw new RuntimeException(e); | ||||
|                             } | ||||
|                             return measurements; | ||||
|                             Iterator<Measurement> iterator = getMeasurementIterator(result); | ||||
|                             return StreamSupport.stream( | ||||
|                                     Spliterators.spliteratorUnknownSize(iterator, Spliterator.NONNULL), true); | ||||
|                         }) | ||||
|                 .flatMap(measurements -> Arrays.stream(measurements).flatMap(Arrays::stream)) | ||||
|                 .filter(Objects::nonNull) | ||||
|                 .collect( | ||||
|                         Collectors.toMap( | ||||
|                                 measurement -> new String(measurement.city), | ||||
| @@ -99,13 +101,48 @@ public class CalculateAverage_kuduwa_keshavram { | ||||
|         System.out.println(resultMap); | ||||
|     } | ||||
|  | ||||
|     private static void putOrMerge(Measurement[][] measurements, Measurement measurement) { | ||||
|         int index = measurement.hash & (measurements.length - 1); | ||||
|         Measurement[] existing = measurements[index]; | ||||
|     private static Iterator<Measurement> getMeasurementIterator(Result result) { | ||||
|         return new Iterator<>() { | ||||
|             final int uniqueIndex = result.uniqueIndex; | ||||
|             final int[] indexArray = result.indexArray; | ||||
|             final Measurement[][] measurements = result.measurements; | ||||
|  | ||||
|             int i = 0; | ||||
|             int j = 0; | ||||
|  | ||||
|             @Override | ||||
|             public boolean hasNext() { | ||||
|                 return i < uniqueIndex; | ||||
|             } | ||||
|  | ||||
|             @Override | ||||
|             public Measurement next() { | ||||
|                 Measurement measurement = measurements[indexArray[i]][j++]; | ||||
|                 if (measurements[indexArray[i]][j] == null) { | ||||
|                     i++; | ||||
|                     j = 0; | ||||
|                 } | ||||
|                 return measurement; | ||||
|             } | ||||
|         }; | ||||
|     } | ||||
|  | ||||
|     static class Result { | ||||
|         final Measurement[][] measurements = new Measurement[1024 * 128][3]; | ||||
|         final int[] indexArray = new int[10_000]; | ||||
|         int uniqueIndex = 0; | ||||
|     } | ||||
|  | ||||
|     private static void putOrMerge(Result result, Measurement measurement) { | ||||
|         int index = measurement.hash & (result.measurements.length - 1); | ||||
|         Measurement[] existing = result.measurements[index]; | ||||
|         for (int i = 0; i < existing.length; i++) { | ||||
|             Measurement existingMeasurement = existing[i]; | ||||
|             if (existingMeasurement == null) { | ||||
|                 measurements[index][i] = measurement; | ||||
|                 result.measurements[index][i] = measurement; | ||||
|                 if (i == 0) { | ||||
|                     result.indexArray[result.uniqueIndex++] = index; | ||||
|                 } | ||||
|                 return; | ||||
|             } | ||||
|             if (equals(existingMeasurement.city, measurement.city)) { | ||||
| @@ -124,13 +161,20 @@ public class CalculateAverage_kuduwa_keshavram { | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     private record FileSegment(long start, long end) { | ||||
|     private static final class FileSegment { | ||||
|         long start; | ||||
|         long end; | ||||
|  | ||||
|         private FileSegment(long start, long end) { | ||||
|             this.start = start; | ||||
|             this.end = end; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static final class Measurement { | ||||
|  | ||||
|         private int hash; | ||||
|         private byte[] city; | ||||
|         private final int hash; | ||||
|         private final byte[] city; | ||||
|  | ||||
|         int min; | ||||
|         int max; | ||||
| @@ -158,45 +202,28 @@ public class CalculateAverage_kuduwa_keshavram { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static List<FileSegment> getFileSegments(final File file) throws IOException { | ||||
|     private static Stream<FileSegment> getFileSegments(final File file) throws IOException { | ||||
|         final int numberOfSegments = Runtime.getRuntime().availableProcessors() * 4; | ||||
|         final long fileSize = file.length(); | ||||
|         final long segmentSize = fileSize / numberOfSegments; | ||||
|         if (segmentSize < 1000) { | ||||
|             return List.of(new FileSegment(0, fileSize)); | ||||
|         } | ||||
|  | ||||
|         try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) { | ||||
|             int lastSegment = numberOfSegments - 1; | ||||
|             return IntStream.range(0, numberOfSegments) | ||||
|                     .mapToObj( | ||||
|                             i -> { | ||||
|                                 long segStart = i * segmentSize; | ||||
|                                 long segEnd = (i == lastSegment) ? fileSize : segStart + segmentSize; | ||||
|                                 try { | ||||
|                                     segStart = findSegment(i, 0, randomAccessFile, segStart, segEnd); | ||||
|                                     segEnd = findSegment(i, lastSegment, randomAccessFile, segEnd, fileSize); | ||||
|                                 } | ||||
|                                 catch (IOException e) { | ||||
|                                     throw new RuntimeException(e); | ||||
|                                 } | ||||
|                                 return new FileSegment(segStart, segEnd); | ||||
|                             }) | ||||
|                     .toList(); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static long findSegment( | ||||
|                                     final int i, final int skipSegment, RandomAccessFile raf, long location, final long fileSize) | ||||
|             throws IOException { | ||||
|         if (i != skipSegment) { | ||||
|             raf.seek(location); | ||||
|             while (location < fileSize) { | ||||
|                 location++; | ||||
|                 if (raf.read() == '\n') | ||||
|                     return location; | ||||
|         final long[] chunks = new long[numberOfSegments + 1]; | ||||
|         try (var fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ)) { | ||||
|             final long fileSize = fileChannel.size(); | ||||
|             final long segmentSize = (fileSize + numberOfSegments - 1) / numberOfSegments; | ||||
|             final long mappedAddress = fileChannel.map(MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); | ||||
|             chunks[0] = mappedAddress; | ||||
|             final long endAddress = mappedAddress + fileSize; | ||||
|             for (int i = 1; i < numberOfSegments; ++i) { | ||||
|                 long chunkAddress = mappedAddress + i * segmentSize; | ||||
|                 // Align to first row start. | ||||
|                 while (chunkAddress < endAddress && UNSAFE.getByte(chunkAddress++) != '\n') { | ||||
|                     // nop | ||||
|                 } | ||||
|                 chunks[i] = Math.min(chunkAddress, endAddress); | ||||
|             } | ||||
|             chunks[numberOfSegments] = endAddress; | ||||
|         } | ||||
|         return location; | ||||
|         return IntStream.range(0, chunks.length - 1) | ||||
|                 .mapToObj(chunkIndex -> new FileSegment(chunks[chunkIndex], chunks[chunkIndex + 1])) | ||||
|                 .parallel(); | ||||
|     } | ||||
|  | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user