From ba793e88cd3c1b7767e00180c721b85cf5c50e28 Mon Sep 17 00:00:00 2001 From: yourwass <157275797+yourwass@users.noreply.github.com> Date: Tue, 23 Jan 2024 21:04:55 +0200 Subject: [PATCH] Add Yourwass take on the challenge (#532) * Uses vector api for city name parsing and for hash index collision resolution * Uses lookup tables for temperature parsing --- calculate_average_yourwass.sh | 23 ++ .../onebrc/CalculateAverage_yourwass.java | 288 ++++++++++++++++++ 2 files changed, 311 insertions(+) create mode 100755 calculate_average_yourwass.sh create mode 100644 src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java diff --git a/calculate_average_yourwass.sh b/calculate_average_yourwass.sh new file mode 100755 index 0000000..07284ba --- /dev/null +++ b/calculate_average_yourwass.sh @@ -0,0 +1,23 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Uncomment below to use sdk +# source "$HOME/.sdkman/bin/sdkman-init.sh" +# sdk use java 21.0.1-graal 1>&2 + +JAVA_OPTS="--enable-preview --enable-native-access=ALL-UNNAMED --add-modules jdk.incubator.vector" +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yourwass diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java b/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java new file mode 100644 index 0000000..0a24b0a --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java @@ -0,0 +1,288 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import java.util.TreeMap; +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.reflect.Field; +import java.nio.channels.FileChannel; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.nio.charset.StandardCharsets; +import java.nio.ByteOrder; +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorSpecies; +import sun.misc.Unsafe; + +public class CalculateAverage_yourwass { + + static final class Record { + public String city; + public long cityAddr; + public long cityLength; + public int min; + public int max; + public int count; + public long sum; + + Record(final long cityAddr, final long cityLength) { + this.city = null; + this.cityAddr = cityAddr; + this.cityLength = cityLength; + this.min = 1000; + this.max = -1000; + this.sum = 0; + this.count = 0; + } + + private Record merge(Record r) { + if (r.min < this.min) + this.min = r.min; + if (r.max > this.max) + this.max = r.max; + this.sum += r.sum; + this.count += r.count; + return this; + } + } + + private static short lookupDecimal[]; + private static byte lookupFraction[]; + private static byte lookupDotPositive[]; + private static byte lookupDotNegative[]; + private static MemorySegment VAS; + private static final VectorSpecies SPECIES = ByteVector.SPECIES_PREFERRED; + private static final int MAXINDEX = (1 << 16) + 10000; // short hash + max allowed cities for collisions at the end :p + private static final String FILE = "measurements.txt"; + private static final Unsafe UNSAFE = getUnsafe(); + + private static Unsafe getUnsafe() { + try { + final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + Unsafe unsafe = (Unsafe) theUnsafe.get(null); + return unsafe; + } + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + public static void main(String[] args) throws IOException, Throwable { + // prepare lookup tables + // the parsing reads two shorts after possible '-' + // first short, the Decimal part, can be N. or NN with N:[0..9] + // second short, the Fraction part, can be N\n or .N + lookupDecimal = new short[('9' << 8) + '9' + 1]; + lookupFraction = new byte[('9' << 8) + '.' + 1]; + lookupDotPositive = new byte[('9' << 8) + '.' + 1]; + lookupDotNegative = new byte[('9' << 8) + '.' + 1]; + for (short i = 0; i < 10; i++) { + final int ones = i * 10; + final int ix256 = i << 8; + // case N. i.e. single digit decimals: skip to 11824 = ('.'<<8)+'0' + lookupDecimal[11824 + i] = (short) ones; + for (short j = 1; j < 10; j++) { + // case NN i.e double digits decimals: skip to 12236 = ('0'<<8)+'0' + lookupDecimal[12336 + ix256 + j] = (short) (j * 100 + ones); + } + // case N\n skip to 2608 = ('\n'<<8)+'0' + lookupFraction[2608 + i] = (byte) i; + lookupDotPositive[2608 + i] = 4; + lookupDotNegative[2608 + i] = 5; + // case .N skip to 12334 = ('0'<<8)+'.' + lookupFraction[12334 + ix256] = (byte) i; + lookupDotPositive[12334 + ix256] = 5; + lookupDotNegative[12334 + ix256] = 6; + } + + // open file + final long fileSize, mmapAddr; + try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { + fileSize = fileChannel.size(); + mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); + } + // VAS: Virtual Address Space, as a MemorySegment upto and including the mmaped file. + // If the mmaped MemorySegment is used for Vector creation as is, then there are two problems: + // 1) fromMemorySegment takes an offset and not an address, so we have to do arithmetic + // this is solved by creating a MemorySegment from Address=0 + // 2) fromMemorySegment checks bounds for memory segment's size - Vector size + // this is solved by adding SPECIES.length() to the size of the segment, but + // XXX there lies the possibility for an out of bounds read at the end of file, which is not handled here. + VAS = MemorySegment.ofAddress(0).reinterpret(mmapAddr + fileSize + SPECIES.length()); + + // start and wait for threads to finish + final int nThreads = Runtime.getRuntime().availableProcessors(); + Thread[] threadList = new Thread[nThreads]; + final Record[][] results = new Record[nThreads][]; + final long chunkSize = fileSize / nThreads; + for (int i = 0; i < nThreads; i++) { + final int threadIndex = i; + final long startAddr = mmapAddr + i * chunkSize; + final long endAddr = (i == nThreads - 1) ? mmapAddr + fileSize : mmapAddr + (i + 1) * chunkSize; + threadList[i] = new Thread(() -> results[threadIndex] = threadMain(threadIndex, startAddr, endAddr, nThreads)); + threadList[i].start(); + } + for (int i = 0; i < nThreads; i++) + threadList[i].join(); + + // aggregate results and sort + // TODO have to compare with concurrent-parallel stream structures: + // * concurrent hashtable that have to sort afterwards + // * concurrent skiplist that is sorted but has O(n) insert + // * ..other? + final TreeMap aggregateResults = new TreeMap<>(); + for (int thread = 0; thread < nThreads; thread++) { + for (int index = 0; index < MAXINDEX; index++) { + Record record = results[thread][index]; + if (record == null) + continue; + aggregateResults.compute(record.city, (k, v) -> (v == null) ? record : v.merge(record)); + } + } + + // prepare string and print + StringBuilder sb = new StringBuilder(); + sb.append("{"); + for (var entry : aggregateResults.entrySet()) { + Record record = entry.getValue(); + float min = record.min; + min /= 10.f; + float max = record.max; + max /= 10.f; + double avg = Math.round((record.sum * 1.0) / record.count) / 10.; + sb.append(record.city).append("=").append(min).append("/").append(avg).append("/").append(max).append(", "); + } + int stringLength = sb.length(); + sb.setCharAt(stringLength - 2, '}'); + sb.setCharAt(stringLength - 1, '\n'); + System.out.print(sb.toString()); + } + + private static final boolean citiesDiffer(final long a, final long b, final long len) { + int part = 0; + for (; part < (len - 1) >> 3; part++) + if (UNSAFE.getLong(a + (part << 3)) != UNSAFE.getLong(b + (part << 3))) + return true; + if (((UNSAFE.getLong(a + (part << 3)) ^ (UNSAFE.getLong(b + (part << 3)))) << ((8 - (len & 7)) << 3)) != 0) + return true; + return false; + } + + private static Record[] threadMain(int id, long startAddr, long endAddr, long nThreads) { + // snap to newlines + if (id != 0) + while (UNSAFE.getByte(startAddr++) != '\n') + ; + if (id != nThreads - 1) + while (UNSAFE.getByte(endAddr++) != '\n') + ; + + final Record[] results = new Record[MAXINDEX]; + final long VECTORBYTESIZE = SPECIES.length(); + final ByteOrder BYTEORDER = ByteOrder.nativeOrder(); + final ByteVector delim = ByteVector.broadcast(SPECIES, ';'); + long nextCityAddr = startAddr; // XXX from these three variables, + long cityAddr = nextCityAddr; // only two are necessary, but if one + long ptr = 0; // is eliminated, on my pc the benchmark gets worse.. + while (nextCityAddr < endAddr) { + // parse city + long mask = ByteVector.fromMemorySegment(SPECIES, VAS, nextCityAddr + ptr, BYTEORDER) + .compare(VectorOperators.EQ, delim).toLong(); + if (mask == 0) { + ptr += VECTORBYTESIZE; + continue; + } + final long cityLength = ptr + Long.numberOfTrailingZeros(mask); + final long tempAddr = cityAddr + cityLength + 1; + + // compute hash table index + int index; + if (cityLength > 1) + index = (UNSAFE.getByte(cityAddr) // mix the first, + ^ (UNSAFE.getByte(cityAddr + 2) << 4) // the third (even if it is the delimiter ';') + ^ (UNSAFE.getByte(tempAddr - 2) << 8) // and the last two bytes of each city's name + ^ (UNSAFE.getByte(tempAddr - 3) << 12)) + & 0xFFFF; + else + index = (UNSAFE.getByte(cityAddr) << 8) & 0xFF00; + + // resolve collisions with linear probing + // use vector api here also, but only if city name fits in one vector length, for faster default case + Record record = results[index]; + if (cityLength <= VECTORBYTESIZE) { + ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER); + while (record != null) { + if (cityLength == record.cityLength) { + long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, record.cityAddr, BYTEORDER) + .compare(VectorOperators.EQ, parsed).toLong(); + if (Long.numberOfTrailingZeros(~sameMask) >= cityLength) + break; + } + record = results[++index]; + } + } + else { // slower normal case for city names with length > VECTORBYTESIZE + while (record != null && (cityLength != record.cityLength || citiesDiffer(record.cityAddr, cityAddr, cityLength))) + record = results[++index]; + } + + // add record for new keys + // TODO have to avoid memory allocations on hot path + if (record == null) { + results[index] = new Record(cityAddr, cityLength); + record = results[index]; + } + + // parse temp with lookup tables + int temp; + if (UNSAFE.getByte(tempAddr) == '-') { + temp = -lookupDecimal[UNSAFE.getShort(tempAddr + 1)] - lookupFraction[UNSAFE.getShort(tempAddr + 3)]; + nextCityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)]; + } + else { + temp = lookupDecimal[UNSAFE.getShort(tempAddr)] + lookupFraction[UNSAFE.getShort(tempAddr + 2)]; + nextCityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)]; + } + cityAddr = nextCityAddr; + ptr = 0; + + // merge record + if (temp < record.min) + record.min = temp; + if (temp > record.max) + record.max = temp; + record.sum += temp; + record.count += 1; + } + + // create strings from raw data + // TODO should avoid this copy + byte b[] = new byte[100]; + for (int i = 0; i < MAXINDEX; i++) { + Record r = results[i]; + if (r == null) + continue; + UNSAFE.copyMemory(null, r.cityAddr, b, Unsafe.ARRAY_BYTE_BASE_OFFSET, r.cityLength); + r.city = new String(b, 0, (int) r.cityLength, StandardCharsets.UTF_8); + } + return results; + } + +}