Update submission (#385)
* feat(flippingbits): Improve parsing of station names * chore(flippingbits): Remove obsolete import * feat(flippingbits): Use custom hash map * feat(flippingbits): Use UNSAFE * fix(flippingbits): Support very small files * chore(flippingbits): Few cleanups * chore(flippingbits): Align names * fix(flippingbits): Initialize hash with first byte * fix(flippingbits): Fix initialization of hash value
This commit is contained in:
		| @@ -15,5 +15,5 @@ | |||||||
| #  limitations under the License. | #  limitations under the License. | ||||||
| # | # | ||||||
|  |  | ||||||
| JAVA_OPTS="--add-modules=jdk.incubator.vector" | JAVA_OPTS="--add-modules=jdk.incubator.vector --enable-preview" | ||||||
| java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_flippingbits | java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_flippingbits | ||||||
|   | |||||||
| @@ -18,8 +18,13 @@ package dev.morling.onebrc; | |||||||
| import jdk.incubator.vector.ShortVector; | import jdk.incubator.vector.ShortVector; | ||||||
| import jdk.incubator.vector.VectorOperators; | import jdk.incubator.vector.VectorOperators; | ||||||
|  |  | ||||||
|  | import sun.misc.Unsafe; | ||||||
|  | import java.lang.foreign.Arena; | ||||||
|  | import java.lang.reflect.Field; | ||||||
|  |  | ||||||
| import java.io.IOException; | import java.io.IOException; | ||||||
| import java.io.RandomAccessFile; | import java.io.RandomAccessFile; | ||||||
|  | import java.nio.channels.FileChannel; | ||||||
| import java.nio.charset.StandardCharsets; | import java.nio.charset.StandardCharsets; | ||||||
| import java.util.*; | import java.util.*; | ||||||
|  |  | ||||||
| @@ -34,14 +39,31 @@ public class CalculateAverage_flippingbits { | |||||||
|  |  | ||||||
|     private static final String FILE = "./measurements.txt"; |     private static final String FILE = "./measurements.txt"; | ||||||
|  |  | ||||||
|     private static final long CHUNK_SIZE = 10 * 1024 * 1024; // 10 MB |     private static final long MINIMUM_FILE_SIZE_PARTITIONING = 10 * 1024 * 1024; // 10 MB | ||||||
|  |  | ||||||
|     private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length(); |     private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length(); | ||||||
|  |  | ||||||
|     private static final int MAX_STATION_NAME_LENGTH = 100; |     private static final int NUM_STATIONS = 10_000; | ||||||
|  |  | ||||||
|  |     private static final int HASH_MAP_OFFSET_CAPACITY = 200_000; | ||||||
|  |  | ||||||
|  |     private static final Unsafe UNSAFE = initUnsafe(); | ||||||
|  |  | ||||||
|  |     private static int HASH_PRIME_NUMBER = 31; | ||||||
|  |  | ||||||
|  |     private static Unsafe initUnsafe() { | ||||||
|  |         try { | ||||||
|  |             Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); | ||||||
|  |             theUnsafe.setAccessible(true); | ||||||
|  |             return (Unsafe) theUnsafe.get(Unsafe.class); | ||||||
|  |         } | ||||||
|  |         catch (NoSuchFieldException | IllegalAccessException e) { | ||||||
|  |             throw new RuntimeException(e); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     public static void main(String[] args) throws IOException { |     public static void main(String[] args) throws IOException { | ||||||
|         var result = Arrays.asList(getSegments()).stream() |         var result = Arrays.asList(getSegments()).parallelStream() | ||||||
|                 .map(segment -> { |                 .map(segment -> { | ||||||
|                     try { |                     try { | ||||||
|                         return processSegment(segment[0], segment[1]); |                         return processSegment(segment[0], segment[1]); | ||||||
| @@ -50,126 +72,137 @@ public class CalculateAverage_flippingbits { | |||||||
|                         throw new RuntimeException(e); |                         throw new RuntimeException(e); | ||||||
|                     } |                     } | ||||||
|                 }) |                 }) | ||||||
|                 .parallel() |                 .reduce(FasterHashMap::mergeWith) | ||||||
|                 .reduce((firstMap, secondMap) -> { |                 .get(); | ||||||
|                     for (var entry : secondMap.entrySet()) { |  | ||||||
|                         PartitionAggregate firstAggregate = firstMap.get(entry.getKey()); |  | ||||||
|                         if (firstAggregate == null) { |  | ||||||
|                             firstMap.put(entry.getKey(), entry.getValue()); |  | ||||||
|                         } |  | ||||||
|                         else { |  | ||||||
|                             firstAggregate.mergeWith(entry.getValue()); |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                     return firstMap; |  | ||||||
|                 }) |  | ||||||
|                 .map(TreeMap::new).get(); |  | ||||||
|  |  | ||||||
|         System.out.println(result); |         var sortedMap = new TreeMap<String, Station>(); | ||||||
|  |         for (Station station : result.getEntries()) { | ||||||
|  |             sortedMap.put(station.getName(), station); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         System.out.println(sortedMap); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     private static long[][] getSegments() throws IOException { |     private static long[][] getSegments() throws IOException { | ||||||
|         try (var file = new RandomAccessFile(FILE, "r")) { |         try (var file = new RandomAccessFile(FILE, "r")) { | ||||||
|             var fileSize = file.length(); |             var channel = file.getChannel(); | ||||||
|  |  | ||||||
|  |             var fileSize = channel.size(); | ||||||
|  |             var startAddress = channel | ||||||
|  |                     .map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()) | ||||||
|  |                     .address(); | ||||||
|  |  | ||||||
|             // Split file into segments, so we can work around the size limitation of channels |             // Split file into segments, so we can work around the size limitation of channels | ||||||
|             var numSegments = (int) (fileSize / CHUNK_SIZE); |             var numSegments = (fileSize > MINIMUM_FILE_SIZE_PARTITIONING) | ||||||
|  |                     ? Runtime.getRuntime().availableProcessors() | ||||||
|  |                     : 1; | ||||||
|  |             var segmentSize = fileSize / numSegments; | ||||||
|  |  | ||||||
|             var boundaries = new long[numSegments + 1][2]; |             var boundaries = new long[numSegments][2]; | ||||||
|             var endPointer = 0L; |             var endPointer = startAddress; | ||||||
|  |  | ||||||
|             for (var i = 0; i < numSegments; i++) { |             for (var i = 0; i < numSegments - 1; i++) { | ||||||
|                 // Start of segment |                 // Start of segment | ||||||
|                 boundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize); |                 boundaries[i][0] = endPointer; | ||||||
|  |  | ||||||
|                 // Seek end of segment, limited by the end of the file |  | ||||||
|                 file.seek(Math.min(boundaries[i][0] + CHUNK_SIZE - 1, fileSize)); |  | ||||||
|  |  | ||||||
|                 // Extend segment until end of line or file |                 // Extend segment until end of line or file | ||||||
|                 while (file.read() != '\n') { |                 endPointer = endPointer + segmentSize; | ||||||
|  |                 while (UNSAFE.getByte(endPointer) != '\n') { | ||||||
|  |                     endPointer++; | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 // End of segment |                 // End of segment | ||||||
|                 endPointer = file.getFilePointer(); |                 boundaries[i][1] = endPointer++; | ||||||
|                 boundaries[i][1] = endPointer; |  | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             boundaries[numSegments][0] = Math.max(endPointer, numSegments * CHUNK_SIZE); |             boundaries[numSegments - 1][0] = endPointer; | ||||||
|             boundaries[numSegments][1] = fileSize; |             boundaries[numSegments - 1][1] = startAddress + fileSize; | ||||||
|  |  | ||||||
|             return boundaries; |             return boundaries; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     private static Map<String, PartitionAggregate> processSegment(long startOfSegment, long endOfSegment) |     private static FasterHashMap processSegment(long startOfSegment, long endOfSegment) throws IOException { | ||||||
|             throws IOException { |         var fasterHashMap = new FasterHashMap(); | ||||||
|         Map<String, PartitionAggregate> stationAggregates = new HashMap<>(50_000); |         for (var i = startOfSegment; i < endOfSegment; i += 3) { | ||||||
|         var byteChunk = new byte[(int) (endOfSegment - startOfSegment)]; |  | ||||||
|         var stationBuffer = new byte[MAX_STATION_NAME_LENGTH]; |  | ||||||
|         try (var file = new RandomAccessFile(FILE, "r")) { |  | ||||||
|             file.seek(startOfSegment); |  | ||||||
|             file.read(byteChunk); |  | ||||||
|             var i = 0; |  | ||||||
|             while (i < byteChunk.length) { |  | ||||||
|                 // Station name has at least one byte |  | ||||||
|                 stationBuffer[0] = byteChunk[i]; |  | ||||||
|                 i++; |  | ||||||
|             // Read station name |             // Read station name | ||||||
|                 var j = 1; |             int nameHash = UNSAFE.getByte(i); | ||||||
|                 while (byteChunk[i] != ';') { |             final var nameStartAddress = i++; | ||||||
|                     stationBuffer[j] = byteChunk[i]; |             var character = UNSAFE.getByte(i); | ||||||
|                     j++; |             while (character != ';') { | ||||||
|  |                 nameHash = nameHash * HASH_PRIME_NUMBER + character; | ||||||
|                 i++; |                 i++; | ||||||
|  |                 character = UNSAFE.getByte(i); | ||||||
|             } |             } | ||||||
|                 var station = new String(stationBuffer, 0, j, StandardCharsets.UTF_8); |             var nameLength = (int) (i - nameStartAddress); | ||||||
|             i++; |             i++; | ||||||
|  |  | ||||||
|             // Read measurement |             // Read measurement | ||||||
|                 var isNegative = byteChunk[i] == '-'; |             var isNegative = UNSAFE.getByte(i) == '-'; | ||||||
|             var measurement = 0; |             var measurement = 0; | ||||||
|             if (isNegative) { |             if (isNegative) { | ||||||
|                 i++; |                 i++; | ||||||
|                     while (byteChunk[i] != '.') { |                 character = UNSAFE.getByte(i); | ||||||
|                         measurement = measurement * 10 + byteChunk[i] - '0'; |                 while (character != '.') { | ||||||
|  |                     measurement = measurement * 10 + character - '0'; | ||||||
|                     i++; |                     i++; | ||||||
|  |                     character = UNSAFE.getByte(i); | ||||||
|                 } |                 } | ||||||
|                     measurement = (measurement * 10 + byteChunk[i + 1] - '0') * -1; |                 measurement = (measurement * 10 + UNSAFE.getByte(i + 1) - '0') * -1; | ||||||
|             } |             } | ||||||
|             else { |             else { | ||||||
|                     while (byteChunk[i] != '.') { |                 character = UNSAFE.getByte(i); | ||||||
|                         measurement = measurement * 10 + byteChunk[i] - '0'; |                 while (character != '.') { | ||||||
|  |                     measurement = measurement * 10 + character - '0'; | ||||||
|                     i++; |                     i++; | ||||||
|  |                     character = UNSAFE.getByte(i); | ||||||
|                 } |                 } | ||||||
|                     measurement = measurement * 10 + byteChunk[i + 1] - '0'; |                 measurement = measurement * 10 + UNSAFE.getByte(i + 1) - '0'; | ||||||
|             } |             } | ||||||
|  |  | ||||||
|                 // Update aggregate |             fasterHashMap.addEntry(nameHash, nameLength, nameStartAddress, (short) measurement); | ||||||
|                 var aggregate = stationAggregates.computeIfAbsent(station, x -> new PartitionAggregate()); |  | ||||||
|                 aggregate.addMeasurementAndComputeAggregate((short) measurement); |  | ||||||
|                 i += 3; |  | ||||||
|             } |  | ||||||
|             stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements); |  | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         return stationAggregates; |         for (Station station : fasterHashMap.getEntries()) { | ||||||
|  |             station.aggregateRemainingMeasurements(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|     private static class PartitionAggregate { |         return fasterHashMap; | ||||||
|         final short[] doubleLane = new short[SIMD_LANE_LENGTH * 2]; |     } | ||||||
|  |  | ||||||
|  |     private static class Station { | ||||||
|  |         final short[] measurements = new short[SIMD_LANE_LENGTH * 2]; | ||||||
|         // Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition |         // Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition | ||||||
|         int count = 0; |         int count = 1; | ||||||
|         long sum = 0; |         long sum = 0; | ||||||
|         short min = Short.MAX_VALUE; |         short min = Short.MAX_VALUE; | ||||||
|         short max = Short.MIN_VALUE; |         short max = Short.MIN_VALUE; | ||||||
|  |         final long nameAddress; | ||||||
|  |         final int nameLength; | ||||||
|  |         final int nameHash; | ||||||
|  |  | ||||||
|  |         public Station(int nameHash, int nameLength, long nameAddress, short measurement) { | ||||||
|  |             this.nameHash = nameHash; | ||||||
|  |             this.nameLength = nameLength; | ||||||
|  |             this.nameAddress = nameAddress; | ||||||
|  |             measurements[0] = measurement; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         public String getName() { | ||||||
|  |             byte[] name = new byte[nameLength]; | ||||||
|  |             UNSAFE.copyMemory(null, nameAddress, name, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength); | ||||||
|  |             return new String(name, StandardCharsets.UTF_8); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         public void addMeasurementAndComputeAggregate(short measurement) { |         public void addMeasurementAndComputeAggregate(short measurement) { | ||||||
|             // Add measurement to buffer, which is later processed by SIMD instructions |             // Add measurement to buffer, which is later processed by SIMD instructions | ||||||
|             doubleLane[count % doubleLane.length] = measurement; |             measurements[count % measurements.length] = measurement; | ||||||
|             count++; |             count++; | ||||||
|  |  | ||||||
|             // Once lane is full, use SIMD instructions to calculate aggregates |             // Once lane is full, use SIMD instructions to calculate aggregates | ||||||
|             if (count % doubleLane.length == 0) { |             if (count % measurements.length == 0) { | ||||||
|                 var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, 0); |                 var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, measurements, 0); | ||||||
|                 var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, SIMD_LANE_LENGTH); |                 var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, measurements, SIMD_LANE_LENGTH); | ||||||
|  |  | ||||||
|                 var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN); |                 var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN); | ||||||
|                 min = (short) Math.min(min, simdMin); |                 min = (short) Math.min(min, simdMin); | ||||||
| @@ -182,19 +215,35 @@ public class CalculateAverage_flippingbits { | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         public void aggregateRemainingMeasurements() { |         public void aggregateRemainingMeasurements() { | ||||||
|             for (var i = 0; i < count % doubleLane.length; i++) { |             for (var i = 0; i < count % measurements.length; i++) { | ||||||
|                 var measurement = doubleLane[i]; |                 var measurement = measurements[i]; | ||||||
|                 min = (short) Math.min(min, measurement); |                 min = (short) Math.min(min, measurement); | ||||||
|                 max = (short) Math.max(max, measurement); |                 max = (short) Math.max(max, measurement); | ||||||
|                 sum += measurement; |                 sum += measurement; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         public void mergeWith(PartitionAggregate otherAggregate) { |         public void mergeWith(Station otherStation) { | ||||||
|             min = (short) Math.min(min, otherAggregate.min); |             min = (short) Math.min(min, otherStation.min); | ||||||
|             max = (short) Math.max(max, otherAggregate.max); |             max = (short) Math.max(max, otherStation.max); | ||||||
|             count = count + otherAggregate.count; |             count = count + otherStation.count; | ||||||
|             sum = sum + otherAggregate.sum; |             sum = sum + otherStation.sum; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         public boolean nameEquals(long otherNameAddress) { | ||||||
|  |             var swarLimit = (nameLength / Long.BYTES) * Long.BYTES; | ||||||
|  |             var i = 0; | ||||||
|  |             for (; i < swarLimit; i += Long.BYTES) { | ||||||
|  |                 if (UNSAFE.getLong(nameAddress + i) != UNSAFE.getLong(otherNameAddress + i)) { | ||||||
|  |                     return false; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             for (; i < nameLength; i++) { | ||||||
|  |                 if (UNSAFE.getByte(nameAddress + i) != UNSAFE.getByte(otherNameAddress + i)) { | ||||||
|  |                     return false; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             return true; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         public String toString() { |         public String toString() { | ||||||
| @@ -206,4 +255,67 @@ public class CalculateAverage_flippingbits { | |||||||
|                     (max / 10.0)); |                     (max / 10.0)); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /** | ||||||
|  |      * Use two arrays for implementing the hash map: | ||||||
|  |      * - The array `entries` holds the map values, in our case instances of the class Station. | ||||||
|  |      * - The array `offsets` maps hashes of the keys to indexes in the `entries` array. | ||||||
|  |      * | ||||||
|  |      * We create `offsets` with a much larger capacity than `entries`, so we minimize collisions. | ||||||
|  |      */ | ||||||
|  |     private static class FasterHashMap { | ||||||
|  |         // Using 16-bit integers (shorts) for offsets supports up to 2^15 (=32,767) entries | ||||||
|  |         // If you need to store more entries, consider replacing short with int | ||||||
|  |         short[] offsets = new short[HASH_MAP_OFFSET_CAPACITY]; | ||||||
|  |         Station[] entries = new Station[NUM_STATIONS + 1]; | ||||||
|  |         int slotsInUse = 0; | ||||||
|  |  | ||||||
|  |         private int getOffsetIdx(int nameHash, int nameLength, long nameAddress) { | ||||||
|  |             var offsetIdx = nameHash & (offsets.length - 1); | ||||||
|  |             var offset = offsets[offsetIdx]; | ||||||
|  |  | ||||||
|  |             while (offset != 0 && | ||||||
|  |                     (nameLength != entries[offset].nameLength || !entries[offset].nameEquals(nameAddress))) { | ||||||
|  |                 offsetIdx = (offsetIdx + 1) % offsets.length; | ||||||
|  |                 offset = offsets[offsetIdx]; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             return offsetIdx; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         public void addEntry(int nameHash, int nameLength, long nameAddress, short measurement) { | ||||||
|  |             var offsetIdx = getOffsetIdx(nameHash, nameLength, nameAddress); | ||||||
|  |             var offset = offsets[offsetIdx]; | ||||||
|  |  | ||||||
|  |             if (offset == 0) { | ||||||
|  |                 slotsInUse++; | ||||||
|  |                 entries[slotsInUse] = new Station(nameHash, nameLength, nameAddress, measurement); | ||||||
|  |                 offsets[offsetIdx] = (short) slotsInUse; | ||||||
|  |             } | ||||||
|  |             else { | ||||||
|  |                 entries[offset].addMeasurementAndComputeAggregate(measurement); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         public FasterHashMap mergeWith(FasterHashMap otherMap) { | ||||||
|  |             for (Station station : otherMap.getEntries()) { | ||||||
|  |                 var offsetIdx = getOffsetIdx(station.nameHash, station.nameLength, station.nameAddress); | ||||||
|  |                 var offset = offsets[offsetIdx]; | ||||||
|  |  | ||||||
|  |                 if (offset == 0) { | ||||||
|  |                     slotsInUse++; | ||||||
|  |                     entries[slotsInUse] = station; | ||||||
|  |                     offsets[offsetIdx] = (short) slotsInUse; | ||||||
|  |                 } | ||||||
|  |                 else { | ||||||
|  |                     entries[offset].mergeWith(station); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             return this; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         public List<Station> getEntries() { | ||||||
|  |             return Arrays.asList(entries).subList(1, slotsInUse + 1); | ||||||
|  |         } | ||||||
|  |     } | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user