diff --git a/calculate_average_flippingbits.sh b/calculate_average_flippingbits.sh
index b37baa0..7dcbe74 100755
--- a/calculate_average_flippingbits.sh
+++ b/calculate_average_flippingbits.sh
@@ -15,5 +15,5 @@
 # limitations under the License.
 #
 
-JAVA_OPTS="--add-modules=jdk.incubator.vector"
+JAVA_OPTS="--add-modules=jdk.incubator.vector --enable-preview"
 java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_flippingbits
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java b/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java
index 2510d85..3489877 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java
@@ -18,8 +18,13 @@ package dev.morling.onebrc;
 
 import jdk.incubator.vector.ShortVector;
 import jdk.incubator.vector.VectorOperators;
+import sun.misc.Unsafe;
 
+import java.lang.foreign.Arena;
+import java.lang.reflect.Field;
+
 import java.io.IOException;
 import java.io.RandomAccessFile;
+import java.nio.channels.FileChannel;
 import java.nio.charset.StandardCharsets;
 import java.util.*;
@@ -34,14 +39,31 @@ public class CalculateAverage_flippingbits {
 
     private static final String FILE = "./measurements.txt";
 
-    private static final long CHUNK_SIZE = 10 * 1024 * 1024; // 10 MB
+    private static final long MINIMUM_FILE_SIZE_PARTITIONING = 10 * 1024 * 1024; // 10 MB
 
     private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length();
 
-    private static final int MAX_STATION_NAME_LENGTH = 100;
+    private static final int NUM_STATIONS = 10_000;
+
+    private static final int HASH_MAP_OFFSET_CAPACITY = 200_000;
+
+    private static final Unsafe UNSAFE = initUnsafe();
+
+    private static int HASH_PRIME_NUMBER = 31;
+
+    private static Unsafe initUnsafe() {
+        try {
+            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+            theUnsafe.setAccessible(true);
+            return (Unsafe) theUnsafe.get(Unsafe.class);
+        }
+        catch (NoSuchFieldException | IllegalAccessException e) {
+            throw new RuntimeException(e);
+        }
+    }
 
     public static void main(String[] args) throws IOException {
-        var result = Arrays.asList(getSegments()).stream()
+        var result = Arrays.asList(getSegments()).parallelStream()
                 .map(segment -> {
                     try {
                         return processSegment(segment[0], segment[1]);
@@ -50,126 +72,137 @@ public class CalculateAverage_flippingbits {
                         throw new RuntimeException(e);
                     }
                 })
-                .parallel()
-                .reduce((firstMap, secondMap) -> {
-                    for (var entry : secondMap.entrySet()) {
-                        PartitionAggregate firstAggregate = firstMap.get(entry.getKey());
-                        if (firstAggregate == null) {
-                            firstMap.put(entry.getKey(), entry.getValue());
-                        }
-                        else {
-                            firstAggregate.mergeWith(entry.getValue());
-                        }
-                    }
-                    return firstMap;
-                })
-                .map(TreeMap::new).get();
+                .reduce(FasterHashMap::mergeWith)
+                .get();
 
-        System.out.println(result);
+        var sortedMap = new TreeMap<String, Station>();
+        for (Station station : result.getEntries()) {
+            sortedMap.put(station.getName(), station);
+        }
+
+        System.out.println(sortedMap);
     }
 
     private static long[][] getSegments() throws IOException {
         try (var file = new RandomAccessFile(FILE, "r")) {
-            var fileSize = file.length();
+            var channel = file.getChannel();
+
+            var fileSize = channel.size();
+            var startAddress = channel
+                    .map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global())
+                    .address();
+
             // Split file into segments, so we can work around the size limitation of channels
-            var numSegments = (int) (fileSize / CHUNK_SIZE);
+            var numSegments = (fileSize > MINIMUM_FILE_SIZE_PARTITIONING)
+                    ? Runtime.getRuntime().availableProcessors()
+                    : 1;
+            var segmentSize = fileSize / numSegments;
 
-            var boundaries = new long[numSegments + 1][2];
-            var endPointer = 0L;
+            var boundaries = new long[numSegments][2];
+            var endPointer = startAddress;
 
-            for (var i = 0; i < numSegments; i++) {
+            for (var i = 0; i < numSegments - 1; i++) {
                 // Start of segment
-                boundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize);
-
-                // Seek end of segment, limited by the end of the file
-                file.seek(Math.min(boundaries[i][0] + CHUNK_SIZE - 1, fileSize));
+                boundaries[i][0] = endPointer;
 
                 // Extend segment until end of line or file
-                while (file.read() != '\n') {
+                endPointer = endPointer + segmentSize;
+                while (UNSAFE.getByte(endPointer) != '\n') {
+                    endPointer++;
                 }
 
                 // End of segment
-                endPointer = file.getFilePointer();
-                boundaries[i][1] = endPointer;
+                boundaries[i][1] = endPointer++;
            }
-            boundaries[numSegments][0] = Math.max(endPointer, numSegments * CHUNK_SIZE);
-            boundaries[numSegments][1] = fileSize;
+            boundaries[numSegments - 1][0] = endPointer;
+            boundaries[numSegments - 1][1] = startAddress + fileSize;
 
             return boundaries;
         }
     }
 
-    private static Map<String, PartitionAggregate> processSegment(long startOfSegment, long endOfSegment)
-            throws IOException {
-        Map<String, PartitionAggregate> stationAggregates = new HashMap<>(50_000);
-        var byteChunk = new byte[(int) (endOfSegment - startOfSegment)];
-        var stationBuffer = new byte[MAX_STATION_NAME_LENGTH];
-        try (var file = new RandomAccessFile(FILE, "r")) {
-            file.seek(startOfSegment);
-            file.read(byteChunk);
-            var i = 0;
-            while (i < byteChunk.length) {
-                // Station name has at least one byte
-                stationBuffer[0] = byteChunk[i];
+    private static FasterHashMap processSegment(long startOfSegment, long endOfSegment) throws IOException {
+        var fasterHashMap = new FasterHashMap();
+        for (var i = startOfSegment; i < endOfSegment; i += 3) {
+            // Read station name
+            int nameHash = UNSAFE.getByte(i);
+            final var nameStartAddress = i++;
+            var character = UNSAFE.getByte(i);
+            while (character != ';') {
+                nameHash = nameHash * HASH_PRIME_NUMBER + character;
                 i++;
-                // Read station name
-                var j = 1;
-                while (byteChunk[i] != ';') {
-                    stationBuffer[j] = byteChunk[i];
-                    j++;
-                    i++;
-                }
-                var station = new String(stationBuffer, 0, j, StandardCharsets.UTF_8);
-                i++;
-
-                // Read measurement
-                var isNegative = byteChunk[i] == '-';
-                var measurement = 0;
-                if (isNegative) {
-                    i++;
-                    while (byteChunk[i] != '.') {
-                        measurement = measurement * 10 + byteChunk[i] - '0';
-                        i++;
-                    }
-                    measurement = (measurement * 10 + byteChunk[i + 1] - '0') * -1;
-                }
-                else {
-                    while (byteChunk[i] != '.') {
-                        measurement = measurement * 10 + byteChunk[i] - '0';
-                        i++;
-                    }
-                    measurement = measurement * 10 + byteChunk[i + 1] - '0';
-                }
-
-                // Update aggregate
-                var aggregate = stationAggregates.computeIfAbsent(station, x -> new PartitionAggregate());
-                aggregate.addMeasurementAndComputeAggregate((short) measurement);
-                i += 3;
+                character = UNSAFE.getByte(i);
             }
-            stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements);
+            var nameLength = (int) (i - nameStartAddress);
+            i++;
+
+            // Read measurement
+            var isNegative = UNSAFE.getByte(i) == '-';
+            var measurement = 0;
+            if (isNegative) {
+                i++;
+                character = UNSAFE.getByte(i);
+                while (character != '.') {
+                    measurement = measurement * 10 + character - '0';
+                    i++;
+                    character = UNSAFE.getByte(i);
+                }
+                measurement = (measurement * 10 + UNSAFE.getByte(i + 1) - '0') * -1;
+            }
+            else {
+                character = UNSAFE.getByte(i);
+                while (character != '.') {
+                    measurement = measurement * 10 + character - '0';
+                    i++;
+                    character = UNSAFE.getByte(i);
+                }
+                measurement = measurement * 10 + UNSAFE.getByte(i + 1) - '0';
+            }
+
+            fasterHashMap.addEntry(nameHash, nameLength, nameStartAddress, (short) measurement);
         }
 
-        return stationAggregates;
+        for (Station station : fasterHashMap.getEntries()) {
+            station.aggregateRemainingMeasurements();
+        }
+
+        return fasterHashMap;
     }
 
-    private static class PartitionAggregate {
-        final short[] doubleLane = new short[SIMD_LANE_LENGTH * 2];
+    private static class Station {
+        final short[] measurements = new short[SIMD_LANE_LENGTH * 2];
         // Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition
-        int count = 0;
+        int count = 1;
        long sum = 0;
        short min = Short.MAX_VALUE;
        short max = Short.MIN_VALUE;
+        final long nameAddress;
+        final int nameLength;
+        final int nameHash;
+
+        public Station(int nameHash, int nameLength, long nameAddress, short measurement) {
+            this.nameHash = nameHash;
+            this.nameLength = nameLength;
+            this.nameAddress = nameAddress;
+            measurements[0] = measurement;
+        }
+
+        public String getName() {
+            byte[] name = new byte[nameLength];
+            UNSAFE.copyMemory(null, nameAddress, name, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength);
+            return new String(name, StandardCharsets.UTF_8);
+        }
 
         public void addMeasurementAndComputeAggregate(short measurement) {
             // Add measurement to buffer, which is later processed by SIMD instructions
-            doubleLane[count % doubleLane.length] = measurement;
+            measurements[count % measurements.length] = measurement;
             count++;
 
             // Once lane is full, use SIMD instructions to calculate aggregates
-            if (count % doubleLane.length == 0) {
-                var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, 0);
-                var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, SIMD_LANE_LENGTH);
+            if (count % measurements.length == 0) {
+                var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, measurements, 0);
+                var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, measurements, SIMD_LANE_LENGTH);
 
                 var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN);
                 min = (short) Math.min(min, simdMin);
@@ -182,19 +215,35 @@ public class CalculateAverage_flippingbits {
         }
 
         public void aggregateRemainingMeasurements() {
-            for (var i = 0; i < count % doubleLane.length; i++) {
-                var measurement = doubleLane[i];
+            for (var i = 0; i < count % measurements.length; i++) {
+                var measurement = measurements[i];
                 min = (short) Math.min(min, measurement);
                 max = (short) Math.max(max, measurement);
                 sum += measurement;
             }
         }
 
-        public void mergeWith(PartitionAggregate otherAggregate) {
-            min = (short) Math.min(min, otherAggregate.min);
-            max = (short) Math.max(max, otherAggregate.max);
-            count = count + otherAggregate.count;
-            sum = sum + otherAggregate.sum;
+        public void mergeWith(Station otherStation) {
+            min = (short) Math.min(min, otherStation.min);
+            max = (short) Math.max(max, otherStation.max);
+            count = count + otherStation.count;
+            sum = sum + otherStation.sum;
+        }
+
+        public boolean nameEquals(long otherNameAddress) {
+            var swarLimit = (nameLength / Long.BYTES) * Long.BYTES;
+            var i = 0;
+            for (; i < swarLimit; i += Long.BYTES) {
+                if (UNSAFE.getLong(nameAddress + i) != UNSAFE.getLong(otherNameAddress + i)) {
+                    return false;
+                }
+            }
+            for (; i < nameLength; i++) {
+                if (UNSAFE.getByte(nameAddress + i) != UNSAFE.getByte(otherNameAddress + i)) {
+                    return false;
+                }
+            }
+            return true;
        }
 
        public String toString() {
@@ -206,4 +255,67 @@
                     (max / 10.0));
         }
     }
+
+    /**
+     * Use two arrays for implementing the hash map:
+     * - The array `entries` holds the map values, in our case instances of the class Station.
+     * - The array `offsets` maps hashes of the keys to indexes in the `entries` array.
+     *
+     * We create `offsets` with a much larger capacity than `entries`, so we minimize collisions.
+     */
+    private static class FasterHashMap {
+        // Using 16-bit integers (shorts) for offsets supports up to 2^15 (=32,767) entries
+        // If you need to store more entries, consider replacing short with int
+        short[] offsets = new short[HASH_MAP_OFFSET_CAPACITY];
+        Station[] entries = new Station[NUM_STATIONS + 1];
+        int slotsInUse = 0;
+
+        private int getOffsetIdx(int nameHash, int nameLength, long nameAddress) {
+            var offsetIdx = nameHash & (offsets.length - 1);
+            var offset = offsets[offsetIdx];
+
+            while (offset != 0
+                    && (nameLength != entries[offset].nameLength || !entries[offset].nameEquals(nameAddress))) {
+                offsetIdx = (offsetIdx + 1) % offsets.length;
+                offset = offsets[offsetIdx];
+            }
+
+            return offsetIdx;
+        }
+
+        public void addEntry(int nameHash, int nameLength, long nameAddress, short measurement) {
+            var offsetIdx = getOffsetIdx(nameHash, nameLength, nameAddress);
+            var offset = offsets[offsetIdx];
+
+            if (offset == 0) {
+                slotsInUse++;
+                entries[slotsInUse] = new Station(nameHash, nameLength, nameAddress, measurement);
+                offsets[offsetIdx] = (short) slotsInUse;
+            }
+            else {
+                entries[offset].addMeasurementAndComputeAggregate(measurement);
+            }
+        }
+
+        public FasterHashMap mergeWith(FasterHashMap otherMap) {
+            for (Station station : otherMap.getEntries()) {
+                var offsetIdx = getOffsetIdx(station.nameHash, station.nameLength, station.nameAddress);
+                var offset = offsets[offsetIdx];
+
+                if (offset == 0) {
+                    slotsInUse++;
+                    entries[slotsInUse] = station;
+                    offsets[offsetIdx] = (short) slotsInUse;
+                }
+                else {
+                    entries[offset].mergeWith(station);
+                }
+            }
+            return this;
+        }
+
+        public List<Station> getEntries() {
+            return Arrays.asList(entries).subList(1, slotsInUse + 1);
+        }
+    }
 }
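
A few notes on the techniques in this patch, each with a small standalone sketch.

First, the segmentation change: instead of seeking through a `RandomAccessFile`, the file is now memory-mapped once via the Foreign Function & Memory API (hence the new `--enable-preview` flag, since `FileChannel.map(..., Arena)` was still a preview API on Java 21), and `getSegments`/`processSegment` walk raw addresses with `sun.misc.Unsafe`. A minimal sketch of just the mapping step, using the safe `MemorySegment` accessor rather than `Unsafe` (class name and error handling here are illustrative, not part of the submission):

```java
import java.io.RandomAccessFile;
import java.lang.foreign.Arena;
import java.lang.foreign.ValueLayout;
import java.nio.channels.FileChannel;

public class MmapSketch {
    public static void main(String[] args) throws Exception {
        try (var file = new RandomAccessFile("./measurements.txt", "r");
                var arena = Arena.ofConfined()) {
            // Map the whole file read-only; the mapping lives as long as the arena
            var segment = file.getChannel()
                    .map(FileChannel.MapMode.READ_ONLY, 0, file.length(), arena);
            // Safe counterpart of the submission's UNSAFE.getByte(address):
            // read the first byte of the mapped file through the segment
            var firstByte = segment.get(ValueLayout.JAVA_BYTE, 0);
            System.out.println(firstByte);
        }
    }
}
```

The submission maps with `Arena.global()` instead of a confined arena, so the mapping stays valid for the entire run and only its base address needs to be passed between threads.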
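
Second, the SIMD step in `Station`: measurements accumulate in a `short[]` of twice the preferred lane width, and once the buffer fills, two vectors are loaded and reduced to scalars. A self-contained sketch of that reduction (compile and run with `--add-modules=jdk.incubator.vector`; the sum handling is an assumption, since the corresponding lines of the submission fall outside the hunks above):

```java
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorOperators;

public class SimdReduceSketch {
    public static void main(String[] args) {
        var species = ShortVector.SPECIES_MAX;
        var laneLength = species.length();
        var measurements = new short[laneLength * 2];
        for (var i = 0; i < measurements.length; i++) {
            measurements[i] = (short) (i - 5); // arbitrary sample data
        }

        // Load the two halves of the buffer as vectors
        var firstVector = ShortVector.fromArray(species, measurements, 0);
        var secondVector = ShortVector.fromArray(species, measurements, laneLength);

        // Element-wise min/max of the two vectors, then a cross-lane reduction
        var min = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN);
        var max = firstVector.max(secondVector).reduceLanes(VectorOperators.MAX);
        // Widening reduction for the sum, so wide species cannot overflow short
        var sum = firstVector.reduceLanesToLong(VectorOperators.ADD)
                + secondVector.reduceLanesToLong(VectorOperators.ADD);

        System.out.println(min + "/" + max + "/" + sum);
    }
}
```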
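
Finally, the two-array layout described in the `FasterHashMap` javadoc, restated with plain `String` keys instead of raw name addresses. All names in this sketch (`TwoArrayMap`, `add`) are hypothetical; it only illustrates the design: an oversized `offsets` table probed linearly, and a dense `entries` region starting at slot 1 so that offset 0 can mean "empty".

```java
import java.util.Arrays;
import java.util.List;

public class TwoArrayMap {
    // `offsets` is deliberately much larger than the entry arrays, so probe
    // chains stay short; 0 marks an empty slot, which is why slot 0 of the
    // entry arrays is never used and real entries start at index 1
    private final short[] offsets = new short[1 << 17];
    private final String[] keys = new String[10_001];
    private final long[] sums = new long[10_001];
    private int slotsInUse = 0;

    public void add(String key, long value) {
        var idx = (key.hashCode() & 0x7FFF_FFFF) % offsets.length;
        // Linear probing: advance until an empty slot or a matching key
        while (offsets[idx] != 0 && !keys[offsets[idx]].equals(key)) {
            idx = (idx + 1) % offsets.length;
        }
        if (offsets[idx] == 0) {
            // New key: claim the next dense slot and record its offset
            slotsInUse++;
            keys[slotsInUse] = key;
            sums[slotsInUse] = value;
            offsets[idx] = (short) slotsInUse;
        }
        else {
            // Existing key: aggregate in place
            sums[offsets[idx]] += value;
        }
    }

    public List<String> keys() {
        // Entries are dense, so iteration never scans empty table slots
        return Arrays.asList(keys).subList(1, slotsInUse + 1);
    }

    public static void main(String[] args) {
        var map = new TwoArrayMap();
        map.add("Hamburg", 123);
        map.add("Hamburg", 45);
        map.add("Accra", 267);
        System.out.println(map.keys()); // [Hamburg, Accra]
    }
}
```

Keeping the entries dense is what makes `getEntries()` and the per-partition merge cheap: iteration touches only `slotsInUse` slots instead of scanning the whole 200,000-element offset table.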