Entry into the contest, calculate_average_mtopolnik.sh (#246)

* calculate_average_mtopolnik

* short hash (just first 8 bytes of name)

* Remove unneeded checks

* Remove archiving classes

* 2x larger hashtable

* Add "set" to setters

* Simplify parsing temperature, remove newline search

* Reduce the size of the name slot

* Store name length and use to detect collision

* Reduce memory loads in parseTemperature

* Use short for min/max

* Extract constant for semicolon

* Fix script header

* Explicit bash shell in shebang

* Inline usage of broadcast semicolon

* Try vectorization

* Remove vectorization

* Go Unsafe

* Use SWAR temperature parsing by merykitty

* Inline some things

* Remove commented-out MemorySegment usage

* Inline namesMem.asSlice() invocation

* Try out JVM JIT flags

* Implement strcmp

* Remove unused instance variables

* Optimize hashing

* Put station name into hashtable

* Reorder method

* Remove usage of MemorySegment.getUtf8String

Replace with UNSAFE.copyMemory() and new String()

* Fix hashing bug

* Remove outdated comments

* Fix informative constants

* Use broadcastByte() more

* Improve method naming

* More hashing

* Revert more hashing

* Add commented-out code to hash 16 bytes

* Slight cleanup

* Align hashtable at cacheline boundary

* Add Graal Native image

* Revert Graal Native image

This reverts commit d916a42326d89bd1a841bbbecfae185adb8679d7.

* Simplify shell script (no SDK selection)

* Move a constant, zero out hashtable on start

* Better name comparison

* Add prepare_mtopolnik.sh

* Cleaner idiom in name comparison

* AND instead of MOD for hashtable indexing

* Improve word masking code

* Fix formatting

* Reduce memory loads

* Remove endianness checks

* Avoid hash == 0 problem

* Fix subtle bug

* MergeSort of parellel results

* Touch up perf

* Touch up perf

* Remove -Xmx256m

* Extract result printing method

* Print allocation details on OOME

* Single mmap

* Use global allocation arena
This commit is contained in:
Marko Topolnik 2024-01-11 20:02:14 +01:00 committed by GitHub
parent 8ec9ba861a
commit 95459f5640
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 570 additions and 0 deletions

21
calculate_average_mtopolnik.sh Executable file
View File

@ -0,0 +1,21 @@
#!/bin/bash
#
# Copyright 2023 The original authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_mtopolnik::recordMeasurementAndAdvanceCursor"
# -XX:InlineSmallCode=10000 -XX:-TieredCompilation -XX:CICompilerCount=2 -XX:CompileThreshold=1000\
java --enable-preview \
--class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_mtopolnik

19
prepare_mtopolnik.sh Executable file
View File

@ -0,0 +1,19 @@
#!/bin/sh
#
# Copyright 2023 The original authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.1-graal 1>&2

View File

@ -0,0 +1,530 @@
/*
* Copyright 2023 The original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.morling.onebrc;
import sun.misc.Unsafe;
import java.io.File;
import java.io.RandomAccessFile;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel.MapMode;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
public class CalculateAverage_mtopolnik {
private static final Unsafe UNSAFE = unsafe();
private static final int MAX_NAME_LEN = 100;
private static final int STATS_TABLE_SIZE = 1 << 16;
private static final int TABLE_INDEX_MASK = STATS_TABLE_SIZE - 1;
private static final String MEASUREMENTS_TXT = "measurements.txt";
private static final byte SEMICOLON = ';';
private static final long BROADCAST_SEMICOLON = broadcastByte(SEMICOLON);
// These two are just informative, I let the IDE calculate them for me
private static final long NATIVE_MEM_PER_THREAD = StatsAccessor.SIZEOF * STATS_TABLE_SIZE;
private static final long NATIVE_MEM_ON_8_THREADS = 8 * NATIVE_MEM_PER_THREAD;
private static Unsafe unsafe() {
try {
Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
theUnsafe.setAccessible(true);
return (Unsafe) theUnsafe.get(Unsafe.class);
}
catch (NoSuchFieldException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}
static class StationStats implements Comparable<StationStats> {
String name;
long sum;
int count;
int min;
int max;
@Override
public String toString() {
return String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum / count) / 10.0, max / 10.0);
}
@Override
public boolean equals(Object that) {
return that.getClass() == StationStats.class && ((StationStats) that).name.equals(this.name);
}
@Override
public int compareTo(StationStats that) {
return name.compareTo(that.name);
}
}
public static void main(String[] args) throws Exception {
calculate();
}
static void calculate() throws Exception {
final File file = new File(MEASUREMENTS_TXT);
final long length = file.length();
final int chunkCount = Runtime.getRuntime().availableProcessors();
final var results = new StationStats[chunkCount][];
final var chunkStartOffsets = new long[chunkCount];
try (var raf = new RandomAccessFile(file, "r")) {
final var inputBase = raf.getChannel().map(MapMode.READ_ONLY, 0, length, Arena.global()).address();
for (int i = 1; i < chunkStartOffsets.length; i++) {
var start = length * i / chunkStartOffsets.length;
raf.seek(start);
while (raf.read() != (byte) '\n') {
}
start = raf.getFilePointer();
chunkStartOffsets[i] = start;
}
var threads = new Thread[chunkCount];
for (int i = 0; i < chunkCount; i++) {
final long chunkStart = chunkStartOffsets[i];
final long chunkLimit = (i + 1 < chunkCount) ? chunkStartOffsets[i + 1] : length;
threads[i] = new Thread(new ChunkProcessor(inputBase + chunkStart, inputBase + chunkLimit, results, i));
}
for (var thread : threads) {
thread.start();
}
for (var thread : threads) {
thread.join();
}
}
mergeSortAndPrint(results);
}
private static class ChunkProcessor implements Runnable {
private static final long NAMEBUF_SIZE = 2 * Long.BYTES;
private static final int CACHELINE_SIZE = 64;
private final long inputBase;
private final long inputSize;
private final StationStats[][] results;
private final int myIndex;
private StatsAccessor stats;
private long nameBufBase;
private long cursor;
ChunkProcessor(long chunkStart, long chunkLimit, StationStats[][] results, int myIndex) {
this.inputBase = chunkStart;
this.inputSize = chunkLimit - chunkStart;
this.results = results;
this.myIndex = myIndex;
}
@Override
public void run() {
try (Arena confinedArena = Arena.ofConfined()) {
long totalAllocated = 0;
String threadName = Thread.currentThread().getName();
long statsByteSize = STATS_TABLE_SIZE * StatsAccessor.SIZEOF;
var diagnosticString = String.format("Thread %s needs %,d bytes, managed to allocate before OOM: ",
threadName, statsByteSize + NAMEBUF_SIZE);
try {
stats = new StatsAccessor(confinedArena.allocate(statsByteSize, CACHELINE_SIZE));
totalAllocated = statsByteSize;
nameBufBase = confinedArena.allocate(NAMEBUF_SIZE).address();
}
catch (OutOfMemoryError e) {
System.err.print(diagnosticString);
System.err.println(totalAllocated);
throw e;
}
processChunk();
exportResults();
}
}
private void processChunk() {
while (cursor < inputSize) {
long word1;
long word2;
if (cursor + 2 * Long.BYTES <= inputSize) {
word1 = UNSAFE.getLong(inputBase + cursor);
word2 = UNSAFE.getLong(inputBase + cursor + Long.BYTES);
}
else {
UNSAFE.putLong(nameBufBase, 0);
UNSAFE.putLong(nameBufBase + Long.BYTES, 0);
UNSAFE.copyMemory(inputBase + cursor, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor));
word1 = UNSAFE.getLong(nameBufBase);
word2 = UNSAFE.getLong(nameBufBase + Long.BYTES);
}
long posOfSemicolon = posOfSemicolon(word1, word2);
word1 = maskWord(word1, posOfSemicolon - cursor);
word2 = maskWord(word2, posOfSemicolon - cursor - Long.BYTES);
long hash = hash(word1);
long namePos = cursor;
long nameLen = posOfSemicolon - cursor;
assert nameLen <= 100 : "nameLen > 100";
int temperature = parseTemperatureAndAdvanceCursor(posOfSemicolon);
updateStats(hash, namePos, nameLen, word1, word2, temperature);
}
}
private void updateStats(long hash, long namePos, long nameLen, long nameWord1, long nameWord2, int temperature) {
int tableIndex = (int) (hash & TABLE_INDEX_MASK);
while (true) {
stats.gotoIndex(tableIndex);
if (stats.hash() == hash && stats.nameLen() == nameLen
&& nameEquals(stats.nameAddress(), inputBase + namePos, nameLen, nameWord1, nameWord2)) {
stats.setSum(stats.sum() + temperature);
stats.setCount(stats.count() + 1);
stats.setMin((short) Integer.min(stats.min(), temperature));
stats.setMax((short) Integer.max(stats.max(), temperature));
return;
}
if (stats.nameLen() != 0) {
tableIndex = (tableIndex + 1) & TABLE_INDEX_MASK;
continue;
}
stats.setHash(hash);
stats.setNameLen((int) nameLen);
stats.setSum(temperature);
stats.setCount(1);
stats.setMin((short) temperature);
stats.setMax((short) temperature);
UNSAFE.copyMemory(inputBase + namePos, stats.nameAddress(), nameLen);
return;
}
}
private int parseTemperatureAndAdvanceCursor(long semicolonPos) {
long startOffset = semicolonPos + 1;
if (startOffset <= inputSize - Long.BYTES) {
return parseTemperatureSwarAndAdvanceCursor(startOffset);
}
return parseTemperatureSimpleAndAdvanceCursor(startOffset);
}
// Credit: merykitty
private int parseTemperatureSwarAndAdvanceCursor(long startOffset) {
long word = UNSAFE.getLong(inputBase + startOffset);
final long negated = ~word;
final int dotPos = Long.numberOfTrailingZeros(negated & 0x10101000);
final long signed = (negated << 59) >> 63;
final long removeSignMask = ~(signed & 0xFF);
final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L;
final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
final int temperature = (int) ((absValue ^ signed) - signed);
cursor = startOffset + (dotPos / 8) + 3;
return temperature;
}
private int parseTemperatureSimpleAndAdvanceCursor(long startOffset) {
final byte minus = (byte) '-';
final byte zero = (byte) '0';
final byte dot = (byte) '.';
// Temperature plus the following newline is at least 4 chars, so this is always safe:
int fourCh = UNSAFE.getInt(inputBase + startOffset);
final int mask = 0xFF;
byte ch = (byte) (fourCh & mask);
int shift = 0;
int temperature;
int sign;
if (ch == minus) {
sign = -1;
shift += 8;
ch = (byte) ((fourCh & (mask << shift)) >>> shift);
}
else {
sign = 1;
}
temperature = ch - zero;
shift += 8;
ch = (byte) ((fourCh & (mask << shift)) >>> shift);
if (ch == dot) {
shift += 8;
ch = (byte) ((fourCh & (mask << shift)) >>> shift);
}
else {
temperature = 10 * temperature + (ch - zero);
shift += 16;
// The last character may be past the four loaded bytes, load it from memory.
// Checking that with another `if` is self-defeating for performance.
ch = UNSAFE.getByte(inputBase + startOffset + (shift / 8));
}
temperature = 10 * temperature + (ch - zero);
// `shift` holds the number of bits in the temperature field.
// A newline character follows the temperature, and so we advance
// the cursor past the newline to the start of the next line.
cursor = startOffset + (shift / 8) + 2;
return sign * temperature;
}
private static long hash(long word1) {
long seed = 0x51_7c_c1_b7_27_22_0a_95L;
int rotDist = 17;
long hash = word1;
hash *= seed;
hash = Long.rotateLeft(hash, rotDist);
// hash ^= word2;
// hash *= seed;
// hash = Long.rotateLeft(hash, rotDist);
return hash;
}
private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2) {
boolean mismatch1 = maskWord(inputWord1, len) != UNSAFE.getLong(statsAddr);
boolean mismatch2 = maskWord(inputWord2, len - Long.BYTES) != UNSAFE.getLong(statsAddr + Long.BYTES);
if (mismatch1 | mismatch2) {
return false;
}
for (int i = 2 * Long.BYTES; i < len; i++) {
if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) {
return false;
}
}
return true;
}
private static long maskWord(long word, long len) {
long halfShiftDistance = Long.max(0, Long.BYTES - len) << 2;
long mask = (~0L >>> halfShiftDistance) >>> halfShiftDistance; // avoid Java trap of shiftDist % 64
return word & mask;
}
private static final long BROADCAST_0x01 = broadcastByte(0x01);
private static final long BROADCAST_0x80 = broadcastByte(0x80);
// Adapted from https://jameshfisher.com/2017/01/24/bitwise-check-for-zero-byte/
// and https://github.com/ashvardanian/StringZilla/blob/14e7a78edcc16b031c06b375aac1f66d8f19d45a/stringzilla/stringzilla.h#L139-L169
long posOfSemicolon(long word1, long word2) {
long diff = word1 ^ BROADCAST_SEMICOLON;
long matchBits1 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
diff = word2 ^ BROADCAST_SEMICOLON;
long matchBits2 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
if ((matchBits1 | matchBits2) != 0) {
int trailing1 = Long.numberOfTrailingZeros(matchBits1);
int match1IsNonZero = trailing1 & 63;
match1IsNonZero |= match1IsNonZero >>> 3;
match1IsNonZero |= match1IsNonZero >>> 1;
match1IsNonZero |= match1IsNonZero >>> 1;
// Now match1IsNonZero is 1 if it's non-zero, else 0. Use it to
// raise the lowest bit in traling2 if trailing1 is nonzero. This forces
// trailing2 to be zero if trailing1 is non-zero.
int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63;
return cursor + ((trailing1 | trailing2) >> 3);
}
long offset = cursor + 2 * Long.BYTES;
for (; offset <= inputSize - Long.BYTES; offset += Long.BYTES) {
var block = UNSAFE.getLong(inputBase + offset);
diff = block ^ BROADCAST_SEMICOLON;
long matchBits = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
if (matchBits != 0) {
return offset + Long.numberOfTrailingZeros(matchBits) / 8;
}
}
return posOfSemicolonSimple(offset);
}
private long posOfSemicolonSimple(long offset) {
for (; offset < inputSize; offset++) {
if (UNSAFE.getByte(inputBase + offset) == SEMICOLON) {
return offset;
}
}
throw new RuntimeException("Semicolon not found");
}
// Copies the results from native memory to Java heap and puts them into the results array.
private void exportResults() {
var exportedStats = new ArrayList<StationStats>(10_000);
for (int i = 0; i < STATS_TABLE_SIZE; i++) {
stats.gotoIndex(i);
if (stats.nameLen() == 0) {
continue;
}
var sum = stats.sum();
var count = stats.count();
var min = stats.min();
var max = stats.max();
var name = stats.exportNameString();
var stationStats = new StationStats();
stationStats.name = name;
stationStats.sum = sum;
stationStats.count = count;
stationStats.min = min;
stationStats.max = max;
exportedStats.add(stationStats);
}
StationStats[] exported = exportedStats.toArray(new StationStats[0]);
Arrays.sort(exported);
results[myIndex] = exported;
}
private final ByteBuffer buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder());
private String longToString(long word) {
buf.clear();
buf.putLong(word);
return new String(buf.array(), StandardCharsets.UTF_8); // + "|" + Arrays.toString(buf.array());
}
}
private static long broadcastByte(int b) {
long nnnnnnnn = b;
nnnnnnnn |= nnnnnnnn << 8;
nnnnnnnn |= nnnnnnnn << 16;
nnnnnnnn |= nnnnnnnn << 32;
return nnnnnnnn;
}
static class StatsAccessor {
static final int NAME_SLOT_SIZE = 104;
static final long HASH_OFFSET = 0;
static final long NAMELEN_OFFSET = HASH_OFFSET + Long.BYTES;
static final long SUM_OFFSET = NAMELEN_OFFSET + Integer.BYTES;
static final long COUNT_OFFSET = SUM_OFFSET + Integer.BYTES;
static final long MIN_OFFSET = COUNT_OFFSET + Integer.BYTES;
static final long MAX_OFFSET = MIN_OFFSET + Short.BYTES;
static final long NAME_OFFSET = MAX_OFFSET + Short.BYTES;
static final long SIZEOF = (NAME_OFFSET + NAME_SLOT_SIZE - 1) / 8 * 8 + 8;
static final int ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class);
private final long address;
private long slotBase;
StatsAccessor(MemorySegment memSeg) {
memSeg.fill((byte) 0);
this.address = memSeg.address();
}
void gotoIndex(int index) {
slotBase = address + index * SIZEOF;
}
long hash() {
return UNSAFE.getLong(slotBase + HASH_OFFSET);
}
int nameLen() {
return UNSAFE.getInt(slotBase + NAMELEN_OFFSET);
}
int sum() {
return UNSAFE.getInt(slotBase + SUM_OFFSET);
}
int count() {
return UNSAFE.getInt(slotBase + COUNT_OFFSET);
}
short min() {
return UNSAFE.getShort(slotBase + MIN_OFFSET);
}
short max() {
return UNSAFE.getShort(slotBase + MAX_OFFSET);
}
long nameAddress() {
return slotBase + NAME_OFFSET;
}
String exportNameString() {
final var bytes = new byte[nameLen()];
UNSAFE.copyMemory(null, nameAddress(), bytes, ARRAY_BASE_OFFSET, nameLen());
return new String(bytes, StandardCharsets.UTF_8);
}
void setHash(long hash) {
UNSAFE.putLong(slotBase + HASH_OFFSET, hash);
}
void setNameLen(int nameLen) {
UNSAFE.putInt(slotBase + NAMELEN_OFFSET, nameLen);
}
void setSum(int sum) {
UNSAFE.putInt(slotBase + SUM_OFFSET, sum);
}
void setCount(int count) {
UNSAFE.putInt(slotBase + COUNT_OFFSET, count);
}
void setMin(short min) {
UNSAFE.putShort(slotBase + MIN_OFFSET, min);
}
void setMax(short max) {
UNSAFE.putShort(slotBase + MAX_OFFSET, max);
}
}
private static void mergeSortAndPrint(StationStats[][] results) {
var onFirst = true;
System.out.print('{');
var cursors = new int[results.length];
var indexOfMin = 0;
StationStats curr = null;
int exhaustedCount;
while (true) {
exhaustedCount = 0;
StationStats min = null;
for (int i = 0; i < cursors.length; i++) {
if (cursors[i] == results[i].length) {
exhaustedCount++;
continue;
}
StationStats candidate = results[i][cursors[i]];
if (min == null || min.compareTo(candidate) > 0) {
indexOfMin = i;
min = candidate;
}
}
if (exhaustedCount == cursors.length) {
if (!onFirst) {
System.out.print(", ");
}
System.out.print(curr);
break;
}
cursors[indexOfMin]++;
if (curr == null) {
curr = min;
}
else if (min.equals(curr)) {
curr.sum += min.sum;
curr.count += min.count;
curr.min = Integer.min(curr.min, min.min);
curr.max = Integer.max(curr.max, min.max);
}
else {
if (onFirst) {
onFirst = false;
}
else {
System.out.print(", ");
}
System.out.print(curr);
curr = min;
}
}
System.out.println('}');
}
}