Update submission (#385)
* feat(flippingbits): Improve parsing of station names
* chore(flippingbits): Remove obsolete import
* feat(flippingbits): Use custom hash map
* feat(flippingbits): Use UNSAFE
* fix(flippingbits): Support very small files
* chore(flippingbits): Few cleanups
* chore(flippingbits): Align names
* fix(flippingbits): Initialize hash with first byte
* fix(flippingbits): Fix initialization of hash value
parent fc6fca4315
commit 3fbc4a2fa8
@@ -15,5 +15,5 @@
 # limitations under the License.
 #
 
-JAVA_OPTS="--add-modules=jdk.incubator.vector"
+JAVA_OPTS="--add-modules=jdk.incubator.vector --enable-preview"
 java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_flippingbits
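This first hunk updates the launcher script; all remaining hunks are in CalculateAverage_flippingbits.java. The added --enable-preview flag is needed because the Java code now memory-maps the input with java.lang.foreign.Arena, which is still a preview API on the JDK 21 toolchain the challenge targets. A minimal sketch of the call the flag unlocks (the class name and path parameter are hypothetical):

import java.io.RandomAccessFile;
import java.lang.foreign.Arena;
import java.nio.channels.FileChannel;

class MmapSketch {
    // Map the whole file read-only and return the base address of the mapping.
    // Arena.global() keeps the mapping alive for the lifetime of the JVM, so
    // the address stays valid even after the channel is closed.
    static long mapReadOnly(String path) throws Exception {
        try (var file = new RandomAccessFile(path, "r")) {
            var channel = file.getChannel();
            return channel
                    .map(FileChannel.MapMode.READ_ONLY, 0, channel.size(), Arena.global())
                    .address();
        }
    }
}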
@@ -18,8 +18,13 @@ package dev.morling.onebrc;
 import jdk.incubator.vector.ShortVector;
 import jdk.incubator.vector.VectorOperators;
 
+import sun.misc.Unsafe;
+import java.lang.foreign.Arena;
+import java.lang.reflect.Field;
+
 import java.io.IOException;
 import java.io.RandomAccessFile;
+import java.nio.channels.FileChannel;
 import java.nio.charset.StandardCharsets;
 import java.util.*;
 
@@ -34,14 +39,31 @@ public class CalculateAverage_flippingbits {
 
     private static final String FILE = "./measurements.txt";
 
-    private static final long CHUNK_SIZE = 10 * 1024 * 1024; // 10 MB
+    private static final long MINIMUM_FILE_SIZE_PARTITIONING = 10 * 1024 * 1024; // 10 MB
 
     private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length();
 
-    private static final int MAX_STATION_NAME_LENGTH = 100;
+    private static final int NUM_STATIONS = 10_000;
 
+    private static final int HASH_MAP_OFFSET_CAPACITY = 200_000;
+
+    private static final Unsafe UNSAFE = initUnsafe();
+
+    private static int HASH_PRIME_NUMBER = 31;
+
+    private static Unsafe initUnsafe() {
+        try {
+            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+            theUnsafe.setAccessible(true);
+            return (Unsafe) theUnsafe.get(Unsafe.class);
+        }
+        catch (NoSuchFieldException | IllegalAccessException e) {
+            throw new RuntimeException(e);
+        }
+    }
 
     public static void main(String[] args) throws IOException {
-        var result = Arrays.asList(getSegments()).stream()
+        var result = Arrays.asList(getSegments()).parallelStream()
                 .map(segment -> {
                     try {
                         return processSegment(segment[0], segment[1]);
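initUnsafe() above is the long-standing reflection idiom for obtaining the sun.misc.Unsafe singleton outside the JDK, and the switch from .stream() with a downstream .parallel() to .parallelStream() is behavior-neutral. A hedged sketch of how the handle is then used for raw reads; allocateMemory is only a stand-in here for the memory-mapped file the submission actually reads from:

import sun.misc.Unsafe;
import java.lang.reflect.Field;

class UnsafeReadSketch {
    public static void main(String[] args) throws Exception {
        // Same acquisition trick as initUnsafe() in the diff
        Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
        theUnsafe.setAccessible(true);
        var unsafe = (Unsafe) theUnsafe.get(Unsafe.class);

        // Write and read one raw byte at an absolute address, as processSegment does
        long address = unsafe.allocateMemory(1);
        unsafe.putByte(address, (byte) ';');
        System.out.println((char) unsafe.getByte(address)); // prints ;
        unsafe.freeMemory(address);
    }
}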
@@ -50,126 +72,137 @@ public class CalculateAverage_flippingbits {
                         throw new RuntimeException(e);
                     }
                 })
-                .parallel()
-                .reduce((firstMap, secondMap) -> {
-                    for (var entry : secondMap.entrySet()) {
-                        PartitionAggregate firstAggregate = firstMap.get(entry.getKey());
-                        if (firstAggregate == null) {
-                            firstMap.put(entry.getKey(), entry.getValue());
-                        }
-                        else {
-                            firstAggregate.mergeWith(entry.getValue());
-                        }
-                    }
-                    return firstMap;
-                })
-                .map(TreeMap::new).get();
+                .reduce(FasterHashMap::mergeWith)
+                .get();
 
-        System.out.println(result);
+        var sortedMap = new TreeMap<String, Station>();
+        for (Station station : result.getEntries()) {
+            sortedMap.put(station.getName(), station);
+        }
+
+        System.out.println(sortedMap);
     }
 
     private static long[][] getSegments() throws IOException {
         try (var file = new RandomAccessFile(FILE, "r")) {
-            var fileSize = file.length();
+            var channel = file.getChannel();
+
+            var fileSize = channel.size();
+            var startAddress = channel
+                    .map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global())
+                    .address();
 
             // Split file into segments, so we can work around the size limitation of channels
-            var numSegments = (int) (fileSize / CHUNK_SIZE);
+            var numSegments = (fileSize > MINIMUM_FILE_SIZE_PARTITIONING)
+                    ? Runtime.getRuntime().availableProcessors()
+                    : 1;
+            var segmentSize = fileSize / numSegments;
 
-            var boundaries = new long[numSegments + 1][2];
-            var endPointer = 0L;
+            var boundaries = new long[numSegments][2];
+            var endPointer = startAddress;
 
-            for (var i = 0; i < numSegments; i++) {
+            for (var i = 0; i < numSegments - 1; i++) {
                 // Start of segment
-                boundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize);
+                boundaries[i][0] = endPointer;
 
-                // Seek end of segment, limited by the end of the file
-                file.seek(Math.min(boundaries[i][0] + CHUNK_SIZE - 1, fileSize));
-
                 // Extend segment until end of line or file
-                while (file.read() != '\n') {
+                endPointer = endPointer + segmentSize;
+                while (UNSAFE.getByte(endPointer) != '\n') {
+                    endPointer++;
                 }
 
                 // End of segment
-                endPointer = file.getFilePointer();
-                boundaries[i][1] = endPointer;
+                boundaries[i][1] = endPointer++;
             }
 
-            boundaries[numSegments][0] = Math.max(endPointer, numSegments * CHUNK_SIZE);
-            boundaries[numSegments][1] = fileSize;
+            boundaries[numSegments - 1][0] = endPointer;
+            boundaries[numSegments - 1][1] = startAddress + fileSize;
 
             return boundaries;
         }
     }
 
-    private static Map<String, PartitionAggregate> processSegment(long startOfSegment, long endOfSegment)
-            throws IOException {
-        Map<String, PartitionAggregate> stationAggregates = new HashMap<>(50_000);
-        var byteChunk = new byte[(int) (endOfSegment - startOfSegment)];
-        var stationBuffer = new byte[MAX_STATION_NAME_LENGTH];
-        try (var file = new RandomAccessFile(FILE, "r")) {
-            file.seek(startOfSegment);
-            file.read(byteChunk);
-            var i = 0;
-            while (i < byteChunk.length) {
-                // Station name has at least one byte
-                stationBuffer[0] = byteChunk[i];
-                i++;
+    private static FasterHashMap processSegment(long startOfSegment, long endOfSegment) throws IOException {
+        var fasterHashMap = new FasterHashMap();
+        for (var i = startOfSegment; i < endOfSegment; i += 3) {
             // Read station name
-                var j = 1;
-                while (byteChunk[i] != ';') {
-                    stationBuffer[j] = byteChunk[i];
-                    j++;
+            int nameHash = UNSAFE.getByte(i);
+            final var nameStartAddress = i++;
+            var character = UNSAFE.getByte(i);
+            while (character != ';') {
+                nameHash = nameHash * HASH_PRIME_NUMBER + character;
                 i++;
+                character = UNSAFE.getByte(i);
             }
-                var station = new String(stationBuffer, 0, j, StandardCharsets.UTF_8);
+            var nameLength = (int) (i - nameStartAddress);
             i++;
 
             // Read measurement
-                var isNegative = byteChunk[i] == '-';
+            var isNegative = UNSAFE.getByte(i) == '-';
             var measurement = 0;
             if (isNegative) {
                 i++;
-                    while (byteChunk[i] != '.') {
-                        measurement = measurement * 10 + byteChunk[i] - '0';
+                character = UNSAFE.getByte(i);
+                while (character != '.') {
+                    measurement = measurement * 10 + character - '0';
                     i++;
+                    character = UNSAFE.getByte(i);
                 }
-                    measurement = (measurement * 10 + byteChunk[i + 1] - '0') * -1;
+                measurement = (measurement * 10 + UNSAFE.getByte(i + 1) - '0') * -1;
             }
             else {
-                    while (byteChunk[i] != '.') {
-                        measurement = measurement * 10 + byteChunk[i] - '0';
+                character = UNSAFE.getByte(i);
+                while (character != '.') {
+                    measurement = measurement * 10 + character - '0';
                     i++;
+                    character = UNSAFE.getByte(i);
                 }
-                    measurement = measurement * 10 + byteChunk[i + 1] - '0';
+                measurement = measurement * 10 + UNSAFE.getByte(i + 1) - '0';
             }
 
-                // Update aggregate
-                var aggregate = stationAggregates.computeIfAbsent(station, x -> new PartitionAggregate());
-                aggregate.addMeasurementAndComputeAggregate((short) measurement);
-                i += 3;
-            }
-            stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements);
+            fasterHashMap.addEntry(nameHash, nameLength, nameStartAddress, (short) measurement);
         }
 
-        return stationAggregates;
+        for (Station station : fasterHashMap.getEntries()) {
+            station.aggregateRemainingMeasurements();
         }
 
-    private static class PartitionAggregate {
-        final short[] doubleLane = new short[SIMD_LANE_LENGTH * 2];
+        return fasterHashMap;
+    }
 
+    private static class Station {
+        final short[] measurements = new short[SIMD_LANE_LENGTH * 2];
         // Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition
-        int count = 0;
+        int count = 1;
        long sum = 0;
         short min = Short.MAX_VALUE;
         short max = Short.MIN_VALUE;
+        final long nameAddress;
+        final int nameLength;
+        final int nameHash;
+
+        public Station(int nameHash, int nameLength, long nameAddress, short measurement) {
+            this.nameHash = nameHash;
+            this.nameLength = nameLength;
+            this.nameAddress = nameAddress;
+            measurements[0] = measurement;
+        }
+
+        public String getName() {
+            byte[] name = new byte[nameLength];
+            UNSAFE.copyMemory(null, nameAddress, name, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength);
+            return new String(name, StandardCharsets.UTF_8);
+        }
 
         public void addMeasurementAndComputeAggregate(short measurement) {
             // Add measurement to buffer, which is later processed by SIMD instructions
-            doubleLane[count % doubleLane.length] = measurement;
+            measurements[count % measurements.length] = measurement;
             count++;
 
             // Once lane is full, use SIMD instructions to calculate aggregates
-            if (count % doubleLane.length == 0) {
-                var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, 0);
-                var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, SIMD_LANE_LENGTH);
+            if (count % measurements.length == 0) {
+                var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, measurements, 0);
+                var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, measurements, SIMD_LANE_LENGTH);
 
                 var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN);
                 min = (short) Math.min(min, simdMin);
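The rewritten processSegment drops the intermediate byte[] copies and walks the mapped region directly through UNSAFE. Two details are easy to miss: the hash is now seeded with the first name byte (int nameHash = UNSAFE.getByte(i)), which is presumably what the two hash-initialization fixes in the commit message refer to, and the loop stride i += 3 is safe because i rests on the decimal point after parsing, and the '.', the fractional digit, and the trailing newline are exactly three bytes. A hedged sketch of the same fixed-point parse over a heap array (ParseSketch and parseTenths are hypothetical, not part of the submission); it assumes the challenge format of an optional '-', one or two integer digits, and exactly one fractional digit:

class ParseSketch {
    // "12.3" -> 123, "-12.3" -> -123 (temperature in tenths of a degree)
    static short parseTenths(byte[] line, int from) {
        var i = from;
        var isNegative = line[i] == '-';
        if (isNegative) {
            i++;
        }
        var measurement = 0;
        while (line[i] != '.') {
            measurement = measurement * 10 + line[i] - '0';
            i++;
        }
        measurement = measurement * 10 + line[i + 1] - '0'; // fractional digit
        return (short) (isNegative ? -measurement : measurement);
    }

    public static void main(String[] args) {
        System.out.println(parseTenths("-12.3".getBytes(), 0)); // prints -123
    }
}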
@@ -182,19 +215,35 @@ public class CalculateAverage_flippingbits {
             }
         }
 
         public void aggregateRemainingMeasurements() {
-            for (var i = 0; i < count % doubleLane.length; i++) {
-                var measurement = doubleLane[i];
+            for (var i = 0; i < count % measurements.length; i++) {
+                var measurement = measurements[i];
                 min = (short) Math.min(min, measurement);
                 max = (short) Math.max(max, measurement);
                 sum += measurement;
             }
         }
 
-        public void mergeWith(PartitionAggregate otherAggregate) {
-            min = (short) Math.min(min, otherAggregate.min);
-            max = (short) Math.max(max, otherAggregate.max);
-            count = count + otherAggregate.count;
-            sum = sum + otherAggregate.sum;
+        public void mergeWith(Station otherStation) {
+            min = (short) Math.min(min, otherStation.min);
+            max = (short) Math.max(max, otherStation.max);
+            count = count + otherStation.count;
+            sum = sum + otherStation.sum;
+        }
+
+        public boolean nameEquals(long otherNameAddress) {
+            var swarLimit = (nameLength / Long.BYTES) * Long.BYTES;
+            var i = 0;
+            for (; i < swarLimit; i += Long.BYTES) {
+                if (UNSAFE.getLong(nameAddress + i) != UNSAFE.getLong(otherNameAddress + i)) {
+                    return false;
+                }
+            }
+            for (; i < nameLength; i++) {
+                if (UNSAFE.getByte(nameAddress + i) != UNSAFE.getByte(otherNameAddress + i)) {
+                    return false;
+                }
+            }
+            return true;
         }
 
         public String toString() {
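The new nameEquals compares station names eight bytes at a time with getLong and finishes the tail byte by byte; since swarLimit rounds the length down to a multiple of Long.BYTES, the long reads never cross the end of either name. The same comparison shape over heap arrays, as a hedged sketch that uses ByteBuffer in place of raw addresses (NameEqualsSketch is hypothetical):

import java.nio.ByteBuffer;

class NameEqualsSketch {
    static boolean bytesEqual(byte[] a, byte[] b, int length) {
        // Assumption: both arrays hold at least `length` bytes
        var bufA = ByteBuffer.wrap(a);
        var bufB = ByteBuffer.wrap(b);
        var swarLimit = (length / Long.BYTES) * Long.BYTES;
        var i = 0;
        for (; i < swarLimit; i += Long.BYTES) { // eight bytes per comparison
            if (bufA.getLong(i) != bufB.getLong(i)) {
                return false;
            }
        }
        for (; i < length; i++) { // remaining tail, byte by byte
            if (a[i] != b[i]) {
                return false;
            }
        }
        return true;
    }
}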
@@ -206,4 +255,67 @@ public class CalculateAverage_flippingbits {
                     (max / 10.0));
         }
     }
+
+    /**
+     * Use two arrays for implementing the hash map:
+     * - The array `entries` holds the map values, in our case instances of the class Station.
+     * - The array `offsets` maps hashes of the keys to indexes in the `entries` array.
+     *
+     * We create `offsets` with a much larger capacity than `entries`, so we minimize collisions.
+     */
+    private static class FasterHashMap {
+        // Using 16-bit integers (shorts) for offsets supports up to 2^15 (=32,767) entries
+        // If you need to store more entries, consider replacing short with int
+        short[] offsets = new short[HASH_MAP_OFFSET_CAPACITY];
+        Station[] entries = new Station[NUM_STATIONS + 1];
+        int slotsInUse = 0;
+
+        private int getOffsetIdx(int nameHash, int nameLength, long nameAddress) {
+            var offsetIdx = nameHash & (offsets.length - 1);
+            var offset = offsets[offsetIdx];
+
+            while (offset != 0 &&
+                    (nameLength != entries[offset].nameLength || !entries[offset].nameEquals(nameAddress))) {
+                offsetIdx = (offsetIdx + 1) % offsets.length;
+                offset = offsets[offsetIdx];
+            }
+
+            return offsetIdx;
+        }
+
+        public void addEntry(int nameHash, int nameLength, long nameAddress, short measurement) {
+            var offsetIdx = getOffsetIdx(nameHash, nameLength, nameAddress);
+            var offset = offsets[offsetIdx];
+
+            if (offset == 0) {
+                slotsInUse++;
+                entries[slotsInUse] = new Station(nameHash, nameLength, nameAddress, measurement);
+                offsets[offsetIdx] = (short) slotsInUse;
+            }
+            else {
+                entries[offset].addMeasurementAndComputeAggregate(measurement);
+            }
+        }
+
+        public FasterHashMap mergeWith(FasterHashMap otherMap) {
+            for (Station station : otherMap.getEntries()) {
+                var offsetIdx = getOffsetIdx(station.nameHash, station.nameLength, station.nameAddress);
+                var offset = offsets[offsetIdx];
+
+                if (offset == 0) {
+                    slotsInUse++;
+                    entries[slotsInUse] = station;
+                    offsets[offsetIdx] = (short) slotsInUse;
+                }
+                else {
+                    entries[offset].mergeWith(station);
+                }
+            }
+            return this;
+        }
+
+        public List<Station> getEntries() {
+            return Arrays.asList(entries).subList(1, slotsInUse + 1);
+        }
+    }
 }
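FasterHashMap is open addressing with linear probing. The offsets table stores 1-based indexes into the dense entries array, so the value 0 can double as the empty-slot sentinel; this is why entries has NUM_STATIONS + 1 elements and getEntries() sublists from index 1. One hedged observation: nameHash & (offsets.length - 1) always stays in bounds because the mask equals capacity - 1, but it only spreads hashes uniformly when the capacity is a power of two; with the 200,000-slot table, the linear probe in getOffsetIdx has to absorb the resulting clustering. A small sketch of the probe walk (ProbeSketch and its values are hypothetical):

class ProbeSketch {
    public static void main(String[] args) {
        var capacity = 200_000; // HASH_MAP_OFFSET_CAPACITY
        var nameHash = "Hamburg".hashCode();
        var offsetIdx = nameHash & (capacity - 1); // initial slot, always < capacity
        for (var step = 0; step < 3; step++) {
            System.out.println(offsetIdx);
            offsetIdx = (offsetIdx + 1) % capacity; // linear probe, wrapping around
        }
    }
}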