/*
 * Copyright 2023 The original authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package dev.morling.onebrc;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;
import java.util.Objects;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import sun.misc.Unsafe;

/**
 * Changelog:
 *
 * Initial submission: 62000 ms
 * Chunked reader: 16000 ms
 * Optimized parser: 13000 ms
 * Branchless methods: 11000 ms
 * Adding memory mapped files: 6500 ms (based on bjhara's submission)
 * Skipping string creation: 4700 ms
 * Custom hashmap: 4200 ms
 * Added SWAR token checks: 3900 ms
 * Skipped String creation: 3500 ms (idea from kgonia)
 * Improved String skip: 3250 ms
 * Segmenting files: 3150 ms (based on spullara's code)
 * Not using SWAR for EOL: 2850 ms
 * Inlining hash calculation: 2450 ms
 * Replacing branchless code: 2200 ms (sometimes we need to kill the things we love)
 * Added unsafe memory access: 1900 ms (keeping the long[] small and local)
 *
 * Best performing JVM on MacBook M2 Pro: 21.0.1-graal
 * `sdk use java 21.0.1-graal`
 *
 */
public class CalculateAverage_royvanrijn {

    private static final String FILE = "./measurements.txt";

    private static final Unsafe UNSAFE = initUnsafe();
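
    // All multi-byte reads below assume a little-endian layout; on a big-endian
    // host the loaded words are byte-reversed first so the same bit tricks apply.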
    private static final boolean isBigEndian = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN);

    private static Unsafe initUnsafe() {
        try {
            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            theUnsafe.setAccessible(true);
            return (Unsafe) theUnsafe.get(Unsafe.class);
        }
        catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) throws Exception {
        new CalculateAverage_royvanrijn().run();
    }

    public void run() throws Exception {

        // Calculate input segments.
        int numberOfChunks = Runtime.getRuntime().availableProcessors();
        long[] chunks = getSegments(numberOfChunks);
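        // chunks holds numberOfChunks + 1 boundary addresses into the memory-mapped
        // file; each adjacent pair delimits one segment, cut on a line boundary.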

        // Parallel processing of segments.
        TreeMap<String, Measurement> results = IntStream.range(0, chunks.length - 1)
                .mapToObj(chunkIndex -> process(chunks[chunkIndex], chunks[chunkIndex + 1])).parallel()
                .flatMap(MeasurementRepository::get)
                .collect(Collectors.toMap(e -> e.city, MeasurementRepository.Entry::measurement, Measurement::updateWith, TreeMap::new));
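        // Per-segment repositories are flattened and merged per city via
        // Measurement::updateWith; the TreeMap keeps the output sorted by city name.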

        System.out.println(results);
    }

    private static long[] getSegments(int numberOfChunks) throws IOException {
        try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
            long fileSize = fileChannel.size();
            long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks;
            long[] chunks = new long[numberOfChunks + 1];
            long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
            chunks[0] = mappedAddress;
            long endAddress = mappedAddress + fileSize;
            for (int i = 1; i < numberOfChunks; ++i) {
                long chunkAddress = mappedAddress + i * segmentSize;
                // Align to first row start.
                while (chunkAddress < endAddress && UNSAFE.getByte(chunkAddress++) != '\n') {
                    // nop
                }
                chunks[i] = Math.min(chunkAddress, endAddress);
            }
            chunks[numberOfChunks] = endAddress;
            return chunks;
        }
    }

    private MeasurementRepository process(long fromAddress, long toAddress) {

        MeasurementRepository repository = new MeasurementRepository();
        long ptr = fromAddress;
        long[] dataBuffer = new long[16];
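        // processEntity parses one "city;temperature" line starting at ptr and returns
        // the address of the next line, so this empty-bodied loop walks the whole segment.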
        while ((ptr = processEntity(dataBuffer, ptr, toAddress, repository)) < toAddress)
            ;

        return repository;
    }
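
    // ';' replicated into every byte lane of a long, used for the SWAR scan in processEntity.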
    private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');

    /**
     * Already looping the longs here, let's shoehorn in making a hash.
     */
    private long processEntity(final long[] data, final long start, final long limit, final MeasurementRepository measurementRepository) {
        int hash = 1;
        long i;
        int dataPtr = 0;
        for (i = start; i <= limit - 8; i += 8) {
            long word = UNSAFE.getLong(i);
            if (isBigEndian) {
                word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this
            }
            final long match = word ^ SEPARATOR_PATTERN;
            long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L;
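            // SWAR: XOR with the broadcast ';' zeroes exactly the bytes equal to ';';
            // the (x - 0x01..01) & ~x & 0x80..80 step then sets the high bit of precisely
            // those zero bytes, so mask != 0 means this word contains the separator.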

            if (mask != 0) {

                final long partialWord = word & ((mask >> 7) - 1);
                hash = longHashStep(hash, partialWord);
                data[dataPtr] = partialWord;

                final int index = Long.numberOfTrailingZeros(mask) >> 3;
                return process(start, i + index, hash, data, measurementRepository);
            }
            data[dataPtr++] = word;
            hash = longHashStep(hash, word);
        }
        // Handle remaining bytes near the limit of the buffer:
        long partialWord = 0;
        int len = 0;
        for (; i < limit; i++) {
            byte read;
            if ((read = UNSAFE.getByte(i)) == ';') {
                hash = longHashStep(hash, partialWord);
                data[dataPtr] = partialWord;
                return process(start, i, hash, data, measurementRepository);
            }
            partialWord = partialWord | ((long) read << (len << 3));
            len++;
        }
        return limit;
    }

    private static final long DOT_BITS = 0x10101000;
    private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
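
    // The parser below reads the up-to-8 bytes following ';' as one long and decodes a
    // temperature of the form [-]d[d].d without branches: bit 4 of an ASCII byte is set
    // for digits '0'..'9' but clear for '.' and '-', so ~word & DOT_BITS locates the
    // decimal point, and the same bit of the first byte reveals an optional minus sign.
    // After masking the sign away and aligning on the '.', a single multiply with
    // MAGIC_MULTIPLIER combines the (up to) three digits into the absolute value
    // expressed in tenths of a degree.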

    private long process(final long startAddress, final long delimiterAddress, final int hash, final long[] data, final MeasurementRepository measurementRepository) {

        long word = UNSAFE.getLong(delimiterAddress + 1);
        if (isBigEndian) {
            word = Long.reverseBytes(word);
        }
        final long invWord = ~word;
        final int decimalSepPos = Long.numberOfTrailingZeros(invWord & DOT_BITS);
        final long signed = (invWord << 59) >> 63;
        final long designMask = ~(signed & 0xFF);
        final long digits = ((word & designMask) << (28 - decimalSepPos)) & 0x0F000F0F00L;
        final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
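        // signed is -1 for a negative reading and 0 otherwise, so the XOR-and-subtract
        // below negates the absolute value only when a '-' was present.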
        final int measurement = (int) ((absValue ^ signed) - signed);

        // Store:
        measurementRepository.update(startAddress, data, (int) (delimiterAddress - startAddress), hash, measurement);
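
        // decimalSepPos >> 3 is the byte offset of the '.' within the word read at
        // delimiterAddress + 1; adding 4 skips the tenths digit and the trailing '\n'.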
        return delimiterAddress + (decimalSepPos >> 3) + 4; // Determine next start.
    }

    static final class Measurement {
        int min, max, count;
        long sum;

        public Measurement() {
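            // Sentinels outside the stored range of tenths (-999..999); they are
            // replaced by the first real measurement.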
            this.min = 1000;
            this.max = -1000;
        }

        public Measurement updateWith(int measurement) {
            min = min(min, measurement);
            max = max(max, measurement);
            sum += measurement;
            count++;
            return this;
        }

        public Measurement updateWith(Measurement measurement) {
            min = min(min, measurement.min);
            max = max(max, measurement.max);
            sum += measurement.sum;
            count += measurement.count;
            return this;
        }

        public String toString() {
            return round(min) + "/" + round((1.0 * sum) / count) + "/" + round(max);
        }
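
        // Values are stored as tenths of a degree; rounding and dividing by 10
        // renders them with one decimal.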
        private double round(double value) {
            return Math.round(value) / 10.0;
        }
    }

    // branchless max (imprecise for large numbers, but good enough)
    static int max(final int a, final int b) {
        final int diff = a - b;
        final int dsgn = diff >> 31;
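        // dsgn is -1 (all ones) when a < b and 0 otherwise, so the subtraction below
        // yields a when a >= b, and a - (a - b) = b when a < b.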
        return a - (diff & dsgn);
    }

    // branchless min (imprecise for large numbers, but good enough)
    static int min(final int a, final int b) {
        final int diff = a - b;
        final int dsgn = diff >> 31;
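        // Same trick: b + (diff & dsgn) is b when a >= b, and b + (a - b) = a when a < b.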
        return b + (diff & dsgn);
    }

    private static int longHashStep(final int hash, final long word) {
        return 31 * hash + (int) (word ^ (word >>> 32));
    }
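
    // Broadcasts a single byte into all 8 byte lanes of a long,
    // e.g. ';' (0x3B) becomes 0x3B3B3B3B3B3B3B3BL.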
    private static long compilePattern(final byte value) {
        return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) |
                ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value;
    }

    /**
     * A normal Java HashMap does all these safety things like boundary checks... we don't need that, we need speeeed.
     *
     * So I've written an extremely simple linear probing hashmap that should work well enough.
     */
    class MeasurementRepository {
        private int tableSize = 1 << 20; // large enough for the contest.
        private int tableMask = (tableSize - 1);

        private MeasurementRepository.Entry[] table = new MeasurementRepository.Entry[tableSize];

        record Entry(long address, long[] data, int length, int hash, String city, Measurement measurement) {
            @Override
            public String toString() {
                return city + "=" + measurement;
            }
        }

        public void update(long address, long[] data, int length, int hash, int temperature) {

            int dataLength = length >> 3;
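            // length is the byte length of the city name; dataLength is the number of
            // complete 8-byte words that were buffered for it.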
            int index = hash & tableMask;
            MeasurementRepository.Entry tableEntry;
            while ((tableEntry = table[index]) != null
                    && (tableEntry.hash != hash || tableEntry.length != length || !arrayEquals(tableEntry.data, data, dataLength))) { // search for the right spot
                index = (index + 1) & tableMask;
            }

            if (tableEntry != null) {
                tableEntry.measurement.updateWith(temperature);
                return;
            }

            // --- This is a brand new entry: insert it into the hashtable and do the slower, one-time calculations here.
            Measurement measurement = new Measurement();

            byte[] bytes = new byte[length];
            for (int i = 0; i < length; i++) {
                bytes[i] = UNSAFE.getByte(address + i);
            }
            String city = new String(bytes);

            long[] dataCopy = new long[dataLength];
            System.arraycopy(data, 0, dataCopy, 0, dataLength);
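            // dataCopy keeps the buffered name words: the caller's scratch array is
            // reused for the next entry, so the new entry needs its own copy.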

            // And add entry:
            MeasurementRepository.Entry toAdd = new MeasurementRepository.Entry(address, dataCopy, length, hash, city, measurement);
            table[index] = toAdd;

            toAdd.measurement.updateWith(temperature);
        }

        public Stream<MeasurementRepository.Entry> get() {
            return Arrays.stream(table).filter(Objects::nonNull);
        }
    }

    /**
     * In case multiple hashes are equal (however unlikely), check the actual key (using longs).
     */
    private boolean arrayEquals(final long[] a, final long[] b, final int length) {
        for (int i = 0; i < length; i++) {
            if (a[i] != b[i])
                return false;
        }
        return true;
    }

}