diff --git a/calculate_average_parkertimmins.sh b/calculate_average_parkertimmins.sh new file mode 100755 index 0000000..88c3871 --- /dev/null +++ b/calculate_average_parkertimmins.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +JAVA_OPTS="--enable-preview --add-modules jdk.incubator.vector" +time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_parkertimmins diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_parkertimmins.java b/src/main/java/dev/morling/onebrc/CalculateAverage_parkertimmins.java new file mode 100644 index 0000000..71412fb --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_parkertimmins.java @@ -0,0 +1,330 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.VectorMask; +import jdk.incubator.vector.VectorOperators; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; + +import java.lang.foreign.ValueLayout; +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.*; +import java.util.concurrent.atomic.AtomicLong; +import java.util.zip.CRC32C; + +public class CalculateAverage_parkertimmins { + private static final String FILE = "./measurements.txt"; + // private static final String FILE = "./full_measurements.no_license"; + + private static record ResultRow(double min, double mean, double max) { + public String toString() { + return round(min) + "/" + round(mean) + "/" + round(max); + } + + private double round(double value) { + return Math.round(value * 10.0) / 10.0; + } + }; + + static class OpenHashTable { + static class Entry { + byte[] key; + short min; + short max; + long sum = 0; + long count = 0; + int hash; + + void merge(OpenHashTable.Entry other) { + min = (short) Math.min(min, other.min); + max = (short) Math.max(max, other.max); + sum += other.sum; + count += other.count; + } + } + + static final int bits = 14; + static final int tableSize = 1 << bits; // 16k + static final int mask = tableSize - 1; + final Entry[] entries = new Entry[tableSize]; + + void add(byte[] buf, int sLen, short val, int hash) { + int idx = hash & mask; + + while (true) { + Entry entry = entries[idx]; + + // key not present, so add it + if (entry == null) { + entry = entries[idx] = new Entry(); + entry.key = Arrays.copyOf(buf, sLen); + entry.min = entry.max = val; + entry.sum += val; + entry.count++; + entry.hash = hash; + break; + } + else { + if (entry.hash == hash && entry.key.length == sLen && Arrays.equals(entry.key, 0, sLen, buf, 0, sLen)) { + entry.min = (short) Math.min(entry.min, val); + entry.max = (short) Math.max(entry.max, val); + entry.sum += val; + entry.count++; + break; + } + else { + idx = (idx + 1) & mask; + } + } + } + } + } + + static long findNextEntryStart(MemorySegment ms, long offset) { + long curr = offset; + while (ms.get(ValueLayout.JAVA_BYTE, curr) != '\n') { + curr++; + } + curr++; + return curr; + } + + static short[] digits10s = { 0, 100, 200, 300, 400, 500, 600, 700, 800, 900 }; + static short[] digits1s = { 0, 10, 20, 30, 40, 50, 60, 70, 80, 90 }; + + static void processRangeScalar(MemorySegment ms, long start, long end, final OpenHashTable localAgg) { + byte[] buf = new byte[128]; + + long curr = start; + long limit = end; + + while (curr < limit) { + int i = 0; + byte val = ms.get(ValueLayout.JAVA_BYTE, curr); + while (val != ';') { + buf[i++] = val; + curr++; + val = ms.get(ValueLayout.JAVA_BYTE, curr); + } + + int sLen = i; + int hash = hash(buf, sLen); + + curr++; // skip semicolon + + long tempIdx = curr; + boolean neg = ms.get(ValueLayout.JAVA_BYTE, tempIdx) == '-'; + boolean twoDig = ms.get(ValueLayout.JAVA_BYTE, tempIdx + 1 + (neg ? 1 : 0)) == '.'; + int len = 3 + (neg ? 1 : 0) + (twoDig ? 0 : 1); + int d0 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 1)) - '0'; + int d1 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 3)) - '0'; + int base = d0 + digits1s[d1] + (twoDig ? 0 : digits10s[((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 4)) - '0']); + short temp = (short) (neg ? -base : base); + + localAgg.add(buf, sLen, temp, hash); + curr = tempIdx + len + 1; + } + } + + static int hash(byte[] buf, int sLen) { + // TODO find a hash that works directly from byte array + // if shorter than 8 chars, mask out upper bits + long mask = sLen < 8 ? -(1L << ((8 - sLen) << 3)) : 0xFFFFFFFFL; + long val = ((buf[0] & 0xffL) << 56) | ((buf[1] & 0xffL) << 48) | ((buf[2] & 0xffL) << 40) | ((buf[3] & 0xffL) << 32) | ((buf[4] & 0xffL) << 24) + | ((buf[5] & 0xffL) << 16) | ((buf[6] & 0xFFL) << 8) | (buf[7] & 0xffL); + val &= mask; + + // also worth trying: https://lemire.me/blog/2015/10/22/faster-hashing-without-effort/ + // lemire: https://lemire.me/blog/2023/07/14/recognizing-string-prefixes-with-simd-instructions/ + int hash = (int) (((((val >> 32) ^ val) & 0xffffffffL) * 3523216699L) >> 32); + return hash; + } + + static void processRangeSIMD(MemorySegment ms, boolean frontPad, boolean backPad, long start, long end, final OpenHashTable localAgg) { + byte[] buf = new byte[128]; + + long curr = frontPad ? findNextEntryStart(ms, start) : start; + long limit = end - padding; + + var needle = ByteVector.broadcast(ByteVector.SPECIES_256, ';'); + while (curr < limit) { + + int segStart = 0; + int sLen; + + while (true) { + var section = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, ms, curr + segStart, ByteOrder.LITTLE_ENDIAN); + section.intoArray(buf, segStart); + VectorMask matches = section.compare(VectorOperators.EQ, needle); + int idx = matches.firstTrue(); + if (idx != 32) { + sLen = segStart + idx; + break; + } + segStart += 32; + } + + int hash = hash(buf, sLen); + + curr += sLen; + curr++; // semicolon + + long tempIdx = curr; + boolean neg = ms.get(ValueLayout.JAVA_BYTE, tempIdx) == '-'; + boolean twoDig = ms.get(ValueLayout.JAVA_BYTE, tempIdx + 1 + (neg ? 1 : 0)) == '.'; + int len = 3 + (neg ? 1 : 0) + (twoDig ? 0 : 1); + int d0 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 1)) - '0'; + int d1 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 3)) - '0'; + int base = d0 + digits1s[d1] + (twoDig ? 0 : digits10s[((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 4)) - '0']); + short temp = (short) (neg ? -base : base); + + localAgg.add(buf, sLen, temp, hash); + curr = tempIdx + len + 1; + } + + // last batch is near end of file, process without SIMD to avoid out-of-bounds + if (!backPad) { + processRangeScalar(ms, curr, end, localAgg); + } + } + + /** + * For debugging issues with hash function + */ + static void checkHashDistributionQuality(ArrayList localAggs) { + HashSet uniquesHashValues = new HashSet(); + HashSet uniqueCities = new HashSet(); + HashMap> cityToHash = new HashMap<>(); + + for (var agg : localAggs) { + for (OpenHashTable.Entry entry : agg.entries) { + if (entry == null) { + continue; + } + uniquesHashValues.add(entry.hash); + String station = new String(entry.key, StandardCharsets.UTF_8); // for UTF-8 encoding + uniqueCities.add(station); + + if (!cityToHash.containsKey(station)) { + cityToHash.put(station, new HashSet<>()); + } + cityToHash.get(station).add(entry.hash); + } + } + + for (var pair : cityToHash.entrySet()) { + if (pair.getValue().size() > 1) { + System.err.println("multiple hashes: " + pair.getKey() + " " + pair.getValue()); + } + } + + System.err.println("Unique stations: " + uniqueCities.size() + ", unique hash values: " + uniquesHashValues.size()); + } + + /** + * Combine thread local values + */ + static HashMap mergeAggregations(ArrayList localAggs) { + HashMap global = new HashMap<>(); + for (var agg : localAggs) { + for (OpenHashTable.Entry entry : agg.entries) { + if (entry == null) { + continue; + } + String station = new String(entry.key, StandardCharsets.UTF_8); // for UTF-8 encoding + var currentVal = global.get(station); + if (currentVal != null) { + currentVal.merge(entry); + } + else { + global.put(station, entry); + } + } + } + return global; + } + + static final long batchSize = 10_000_000; + + static final int padding = 200; // max entry size is 107ish == 100 (station) + 1 (semicolon) + 5 (temp, eg -99.9) + 1 (newline) + + public static void main(String[] args) throws IOException, InterruptedException { + RandomAccessFile file = new RandomAccessFile(FILE, "r"); + FileChannel channel = file.getChannel(); + + int numThreads = Runtime.getRuntime().availableProcessors(); + + final long fileSize = channel.size(); + final MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()); + final ArrayList localAggs = new ArrayList<>(numThreads); + Thread[] threads = new Thread[numThreads]; + final AtomicLong progress = new AtomicLong(0); + + class Task implements Runnable { + final int threadId; + + Task(int threadId) { + this.threadId = threadId; + } + + @Override + public void run() { + var localAgg = localAggs.get(threadId); + while (true) { + final long startBatch = progress.getAndAdd(batchSize); + if (startBatch >= fileSize) { + break; + } + final long endBatch = Math.min(startBatch + batchSize, fileSize); + final boolean first = startBatch == 0; + final boolean frontPad = !first; + final boolean last = endBatch == fileSize; + final boolean backPad = !last; + processRangeSIMD(ms, frontPad, backPad, startBatch, endBatch, localAgg); + } + } + } + + for (int t = 0; t < numThreads; t++) { + localAggs.add(new OpenHashTable()); + threads[t] = new Thread(new Task(t), "Thread-" + t); + threads[t].start(); + } + + for (var thread : threads) { + thread.join(); + } + + var globalAggs = mergeAggregations(localAggs); + + Map res = new TreeMap<>(); + for (Map.Entry entry : globalAggs.entrySet()) { + final var ma = entry.getValue(); + res.put(entry.getKey(), new ResultRow(ma.min / 10.0, (ma.sum / 10.0) / ma.count, ma.max / 10.0)); + } + System.out.println(res); + } +}