Update submission (#385)
* feat(flippingbits): Improve parsing of station names * chore(flippingbits): Remove obsolete import * feat(flippingbits): Use custom hash map * feat(flippingbits): Use UNSAFE * fix(flippingbits): Support very small files * chore(flippingbits): Few cleanups * chore(flippingbits): Align names * fix(flippingbits): Initialize hash with first byte * fix(flippingbits): Fix initialization of hash value
This commit is contained in:
		| @@ -15,5 +15,5 @@ | ||||
| #  limitations under the License. | ||||
| # | ||||
|  | ||||
| JAVA_OPTS="--add-modules=jdk.incubator.vector" | ||||
| JAVA_OPTS="--add-modules=jdk.incubator.vector --enable-preview" | ||||
| java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_flippingbits | ||||
|   | ||||
| @@ -18,8 +18,13 @@ package dev.morling.onebrc; | ||||
| import jdk.incubator.vector.ShortVector; | ||||
| import jdk.incubator.vector.VectorOperators; | ||||
|  | ||||
| import sun.misc.Unsafe; | ||||
| import java.lang.foreign.Arena; | ||||
| import java.lang.reflect.Field; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.io.RandomAccessFile; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.util.*; | ||||
|  | ||||
| @@ -34,14 +39,31 @@ public class CalculateAverage_flippingbits { | ||||
|  | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|  | ||||
|     private static final long CHUNK_SIZE = 10 * 1024 * 1024; // 10 MB | ||||
|     private static final long MINIMUM_FILE_SIZE_PARTITIONING = 10 * 1024 * 1024; // 10 MB | ||||
|  | ||||
|     private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length(); | ||||
|  | ||||
|     private static final int MAX_STATION_NAME_LENGTH = 100; | ||||
|     private static final int NUM_STATIONS = 10_000; | ||||
|  | ||||
|     private static final int HASH_MAP_OFFSET_CAPACITY = 200_000; | ||||
|  | ||||
|     private static final Unsafe UNSAFE = initUnsafe(); | ||||
|  | ||||
|     private static int HASH_PRIME_NUMBER = 31; | ||||
|  | ||||
|     private static Unsafe initUnsafe() { | ||||
|         try { | ||||
|             Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); | ||||
|             theUnsafe.setAccessible(true); | ||||
|             return (Unsafe) theUnsafe.get(Unsafe.class); | ||||
|         } | ||||
|         catch (NoSuchFieldException | IllegalAccessException e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     public static void main(String[] args) throws IOException { | ||||
|         var result = Arrays.asList(getSegments()).stream() | ||||
|         var result = Arrays.asList(getSegments()).parallelStream() | ||||
|                 .map(segment -> { | ||||
|                     try { | ||||
|                         return processSegment(segment[0], segment[1]); | ||||
| @@ -50,126 +72,137 @@ public class CalculateAverage_flippingbits { | ||||
|                         throw new RuntimeException(e); | ||||
|                     } | ||||
|                 }) | ||||
|                 .parallel() | ||||
|                 .reduce((firstMap, secondMap) -> { | ||||
|                     for (var entry : secondMap.entrySet()) { | ||||
|                         PartitionAggregate firstAggregate = firstMap.get(entry.getKey()); | ||||
|                         if (firstAggregate == null) { | ||||
|                             firstMap.put(entry.getKey(), entry.getValue()); | ||||
|                         } | ||||
|                         else { | ||||
|                             firstAggregate.mergeWith(entry.getValue()); | ||||
|                         } | ||||
|                     } | ||||
|                     return firstMap; | ||||
|                 }) | ||||
|                 .map(TreeMap::new).get(); | ||||
|                 .reduce(FasterHashMap::mergeWith) | ||||
|                 .get(); | ||||
|  | ||||
|         System.out.println(result); | ||||
|         var sortedMap = new TreeMap<String, Station>(); | ||||
|         for (Station station : result.getEntries()) { | ||||
|             sortedMap.put(station.getName(), station); | ||||
|         } | ||||
|  | ||||
|         System.out.println(sortedMap); | ||||
|     } | ||||
|  | ||||
|     private static long[][] getSegments() throws IOException { | ||||
|         try (var file = new RandomAccessFile(FILE, "r")) { | ||||
|             var fileSize = file.length(); | ||||
|             var channel = file.getChannel(); | ||||
|  | ||||
|             var fileSize = channel.size(); | ||||
|             var startAddress = channel | ||||
|                     .map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()) | ||||
|                     .address(); | ||||
|  | ||||
|             // Split file into segments, so we can work around the size limitation of channels | ||||
|             var numSegments = (int) (fileSize / CHUNK_SIZE); | ||||
|             var numSegments = (fileSize > MINIMUM_FILE_SIZE_PARTITIONING) | ||||
|                     ? Runtime.getRuntime().availableProcessors() | ||||
|                     : 1; | ||||
|             var segmentSize = fileSize / numSegments; | ||||
|  | ||||
|             var boundaries = new long[numSegments + 1][2]; | ||||
|             var endPointer = 0L; | ||||
|             var boundaries = new long[numSegments][2]; | ||||
|             var endPointer = startAddress; | ||||
|  | ||||
|             for (var i = 0; i < numSegments; i++) { | ||||
|             for (var i = 0; i < numSegments - 1; i++) { | ||||
|                 // Start of segment | ||||
|                 boundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize); | ||||
|  | ||||
|                 // Seek end of segment, limited by the end of the file | ||||
|                 file.seek(Math.min(boundaries[i][0] + CHUNK_SIZE - 1, fileSize)); | ||||
|                 boundaries[i][0] = endPointer; | ||||
|  | ||||
|                 // Extend segment until end of line or file | ||||
|                 while (file.read() != '\n') { | ||||
|                 endPointer = endPointer + segmentSize; | ||||
|                 while (UNSAFE.getByte(endPointer) != '\n') { | ||||
|                     endPointer++; | ||||
|                 } | ||||
|  | ||||
|                 // End of segment | ||||
|                 endPointer = file.getFilePointer(); | ||||
|                 boundaries[i][1] = endPointer; | ||||
|                 boundaries[i][1] = endPointer++; | ||||
|             } | ||||
|  | ||||
|             boundaries[numSegments][0] = Math.max(endPointer, numSegments * CHUNK_SIZE); | ||||
|             boundaries[numSegments][1] = fileSize; | ||||
|             boundaries[numSegments - 1][0] = endPointer; | ||||
|             boundaries[numSegments - 1][1] = startAddress + fileSize; | ||||
|  | ||||
|             return boundaries; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static Map<String, PartitionAggregate> processSegment(long startOfSegment, long endOfSegment) | ||||
|             throws IOException { | ||||
|         Map<String, PartitionAggregate> stationAggregates = new HashMap<>(50_000); | ||||
|         var byteChunk = new byte[(int) (endOfSegment - startOfSegment)]; | ||||
|         var stationBuffer = new byte[MAX_STATION_NAME_LENGTH]; | ||||
|         try (var file = new RandomAccessFile(FILE, "r")) { | ||||
|             file.seek(startOfSegment); | ||||
|             file.read(byteChunk); | ||||
|             var i = 0; | ||||
|             while (i < byteChunk.length) { | ||||
|                 // Station name has at least one byte | ||||
|                 stationBuffer[0] = byteChunk[i]; | ||||
|     private static FasterHashMap processSegment(long startOfSegment, long endOfSegment) throws IOException { | ||||
|         var fasterHashMap = new FasterHashMap(); | ||||
|         for (var i = startOfSegment; i < endOfSegment; i += 3) { | ||||
|             // Read station name | ||||
|             int nameHash = UNSAFE.getByte(i); | ||||
|             final var nameStartAddress = i++; | ||||
|             var character = UNSAFE.getByte(i); | ||||
|             while (character != ';') { | ||||
|                 nameHash = nameHash * HASH_PRIME_NUMBER + character; | ||||
|                 i++; | ||||
|                 // Read station name | ||||
|                 var j = 1; | ||||
|                 while (byteChunk[i] != ';') { | ||||
|                     stationBuffer[j] = byteChunk[i]; | ||||
|                     j++; | ||||
|                     i++; | ||||
|                 } | ||||
|                 var station = new String(stationBuffer, 0, j, StandardCharsets.UTF_8); | ||||
|                 i++; | ||||
|  | ||||
|                 // Read measurement | ||||
|                 var isNegative = byteChunk[i] == '-'; | ||||
|                 var measurement = 0; | ||||
|                 if (isNegative) { | ||||
|                     i++; | ||||
|                     while (byteChunk[i] != '.') { | ||||
|                         measurement = measurement * 10 + byteChunk[i] - '0'; | ||||
|                         i++; | ||||
|                     } | ||||
|                     measurement = (measurement * 10 + byteChunk[i + 1] - '0') * -1; | ||||
|                 } | ||||
|                 else { | ||||
|                     while (byteChunk[i] != '.') { | ||||
|                         measurement = measurement * 10 + byteChunk[i] - '0'; | ||||
|                         i++; | ||||
|                     } | ||||
|                     measurement = measurement * 10 + byteChunk[i + 1] - '0'; | ||||
|                 } | ||||
|  | ||||
|                 // Update aggregate | ||||
|                 var aggregate = stationAggregates.computeIfAbsent(station, x -> new PartitionAggregate()); | ||||
|                 aggregate.addMeasurementAndComputeAggregate((short) measurement); | ||||
|                 i += 3; | ||||
|                 character = UNSAFE.getByte(i); | ||||
|             } | ||||
|             stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements); | ||||
|             var nameLength = (int) (i - nameStartAddress); | ||||
|             i++; | ||||
|  | ||||
|             // Read measurement | ||||
|             var isNegative = UNSAFE.getByte(i) == '-'; | ||||
|             var measurement = 0; | ||||
|             if (isNegative) { | ||||
|                 i++; | ||||
|                 character = UNSAFE.getByte(i); | ||||
|                 while (character != '.') { | ||||
|                     measurement = measurement * 10 + character - '0'; | ||||
|                     i++; | ||||
|                     character = UNSAFE.getByte(i); | ||||
|                 } | ||||
|                 measurement = (measurement * 10 + UNSAFE.getByte(i + 1) - '0') * -1; | ||||
|             } | ||||
|             else { | ||||
|                 character = UNSAFE.getByte(i); | ||||
|                 while (character != '.') { | ||||
|                     measurement = measurement * 10 + character - '0'; | ||||
|                     i++; | ||||
|                     character = UNSAFE.getByte(i); | ||||
|                 } | ||||
|                 measurement = measurement * 10 + UNSAFE.getByte(i + 1) - '0'; | ||||
|             } | ||||
|  | ||||
|             fasterHashMap.addEntry(nameHash, nameLength, nameStartAddress, (short) measurement); | ||||
|         } | ||||
|  | ||||
|         return stationAggregates; | ||||
|         for (Station station : fasterHashMap.getEntries()) { | ||||
|             station.aggregateRemainingMeasurements(); | ||||
|         } | ||||
|  | ||||
|         return fasterHashMap; | ||||
|     } | ||||
|  | ||||
|     private static class PartitionAggregate { | ||||
|         final short[] doubleLane = new short[SIMD_LANE_LENGTH * 2]; | ||||
|     private static class Station { | ||||
|         final short[] measurements = new short[SIMD_LANE_LENGTH * 2]; | ||||
|         // Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition | ||||
|         int count = 0; | ||||
|         int count = 1; | ||||
|         long sum = 0; | ||||
|         short min = Short.MAX_VALUE; | ||||
|         short max = Short.MIN_VALUE; | ||||
|         final long nameAddress; | ||||
|         final int nameLength; | ||||
|         final int nameHash; | ||||
|  | ||||
|         public Station(int nameHash, int nameLength, long nameAddress, short measurement) { | ||||
|             this.nameHash = nameHash; | ||||
|             this.nameLength = nameLength; | ||||
|             this.nameAddress = nameAddress; | ||||
|             measurements[0] = measurement; | ||||
|         } | ||||
|  | ||||
|         public String getName() { | ||||
|             byte[] name = new byte[nameLength]; | ||||
|             UNSAFE.copyMemory(null, nameAddress, name, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength); | ||||
|             return new String(name, StandardCharsets.UTF_8); | ||||
|         } | ||||
|  | ||||
|         public void addMeasurementAndComputeAggregate(short measurement) { | ||||
|             // Add measurement to buffer, which is later processed by SIMD instructions | ||||
|             doubleLane[count % doubleLane.length] = measurement; | ||||
|             measurements[count % measurements.length] = measurement; | ||||
|             count++; | ||||
|  | ||||
|             // Once lane is full, use SIMD instructions to calculate aggregates | ||||
|             if (count % doubleLane.length == 0) { | ||||
|                 var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, 0); | ||||
|                 var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, SIMD_LANE_LENGTH); | ||||
|             if (count % measurements.length == 0) { | ||||
|                 var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, measurements, 0); | ||||
|                 var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, measurements, SIMD_LANE_LENGTH); | ||||
|  | ||||
|                 var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN); | ||||
|                 min = (short) Math.min(min, simdMin); | ||||
| @@ -182,19 +215,35 @@ public class CalculateAverage_flippingbits { | ||||
|         } | ||||
|  | ||||
|         public void aggregateRemainingMeasurements() { | ||||
|             for (var i = 0; i < count % doubleLane.length; i++) { | ||||
|                 var measurement = doubleLane[i]; | ||||
|             for (var i = 0; i < count % measurements.length; i++) { | ||||
|                 var measurement = measurements[i]; | ||||
|                 min = (short) Math.min(min, measurement); | ||||
|                 max = (short) Math.max(max, measurement); | ||||
|                 sum += measurement; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public void mergeWith(PartitionAggregate otherAggregate) { | ||||
|             min = (short) Math.min(min, otherAggregate.min); | ||||
|             max = (short) Math.max(max, otherAggregate.max); | ||||
|             count = count + otherAggregate.count; | ||||
|             sum = sum + otherAggregate.sum; | ||||
|         public void mergeWith(Station otherStation) { | ||||
|             min = (short) Math.min(min, otherStation.min); | ||||
|             max = (short) Math.max(max, otherStation.max); | ||||
|             count = count + otherStation.count; | ||||
|             sum = sum + otherStation.sum; | ||||
|         } | ||||
|  | ||||
|         public boolean nameEquals(long otherNameAddress) { | ||||
|             var swarLimit = (nameLength / Long.BYTES) * Long.BYTES; | ||||
|             var i = 0; | ||||
|             for (; i < swarLimit; i += Long.BYTES) { | ||||
|                 if (UNSAFE.getLong(nameAddress + i) != UNSAFE.getLong(otherNameAddress + i)) { | ||||
|                     return false; | ||||
|                 } | ||||
|             } | ||||
|             for (; i < nameLength; i++) { | ||||
|                 if (UNSAFE.getByte(nameAddress + i) != UNSAFE.getByte(otherNameAddress + i)) { | ||||
|                     return false; | ||||
|                 } | ||||
|             } | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|         public String toString() { | ||||
| @@ -206,4 +255,67 @@ public class CalculateAverage_flippingbits { | ||||
|                     (max / 10.0)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Use two arrays for implementing the hash map: | ||||
|      * - The array `entries` holds the map values, in our case instances of the class Station. | ||||
|      * - The array `offsets` maps hashes of the keys to indexes in the `entries` array. | ||||
|      * | ||||
|      * We create `offsets` with a much larger capacity than `entries`, so we minimize collisions. | ||||
|      */ | ||||
|     private static class FasterHashMap { | ||||
|         // Using 16-bit integers (shorts) for offsets supports up to 2^15 (=32,767) entries | ||||
|         // If you need to store more entries, consider replacing short with int | ||||
|         short[] offsets = new short[HASH_MAP_OFFSET_CAPACITY]; | ||||
|         Station[] entries = new Station[NUM_STATIONS + 1]; | ||||
|         int slotsInUse = 0; | ||||
|  | ||||
|         private int getOffsetIdx(int nameHash, int nameLength, long nameAddress) { | ||||
|             var offsetIdx = nameHash & (offsets.length - 1); | ||||
|             var offset = offsets[offsetIdx]; | ||||
|  | ||||
|             while (offset != 0 && | ||||
|                     (nameLength != entries[offset].nameLength || !entries[offset].nameEquals(nameAddress))) { | ||||
|                 offsetIdx = (offsetIdx + 1) % offsets.length; | ||||
|                 offset = offsets[offsetIdx]; | ||||
|             } | ||||
|  | ||||
|             return offsetIdx; | ||||
|         } | ||||
|  | ||||
|         public void addEntry(int nameHash, int nameLength, long nameAddress, short measurement) { | ||||
|             var offsetIdx = getOffsetIdx(nameHash, nameLength, nameAddress); | ||||
|             var offset = offsets[offsetIdx]; | ||||
|  | ||||
|             if (offset == 0) { | ||||
|                 slotsInUse++; | ||||
|                 entries[slotsInUse] = new Station(nameHash, nameLength, nameAddress, measurement); | ||||
|                 offsets[offsetIdx] = (short) slotsInUse; | ||||
|             } | ||||
|             else { | ||||
|                 entries[offset].addMeasurementAndComputeAggregate(measurement); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public FasterHashMap mergeWith(FasterHashMap otherMap) { | ||||
|             for (Station station : otherMap.getEntries()) { | ||||
|                 var offsetIdx = getOffsetIdx(station.nameHash, station.nameLength, station.nameAddress); | ||||
|                 var offset = offsets[offsetIdx]; | ||||
|  | ||||
|                 if (offset == 0) { | ||||
|                     slotsInUse++; | ||||
|                     entries[slotsInUse] = station; | ||||
|                     offsets[offsetIdx] = (short) slotsInUse; | ||||
|                 } | ||||
|                 else { | ||||
|                     entries[offset].mergeWith(station); | ||||
|                 } | ||||
|             } | ||||
|             return this; | ||||
|         } | ||||
|  | ||||
|         public List<Station> getEntries() { | ||||
|             return Arrays.asList(entries).subList(1, slotsInUse + 1); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user