/* * Copyright 2023 The original authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package dev.morling.onebrc; import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.VectorOperators; import jdk.incubator.vector.VectorSpecies; import sun.misc.Unsafe; import java.io.IOException; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.reflect.Field; import java.nio.ByteOrder; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.*; /** * Changelog (based on Macbook Pro Intel i7 6-cores 2.6GHz): * * Initial 40000 ms * Parse key as byte vs string 30000 ms * Parse temp as fixed vs double 15000 ms * HashMap optimization 10000 ms * Simd + reduce memory copy 8000 ms * */ public class CalculateAverage_jincongho { private static final String FILE = "./measurements.txt"; private static final Unsafe UNSAFE = initUnsafe(); private static Unsafe initUnsafe() { try { Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); theUnsafe.setAccessible(true); return (Unsafe) theUnsafe.get(Unsafe.class); } catch (NoSuchFieldException | IllegalAccessException e) { throw new RuntimeException(e); } } /** * Vectorization utilities with 1BRC-specific optimizations */ protected static class VectorUtils { // key length is usually less than 32 bytes, having more is just expensive public static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_256; /** Vectorized field delimiter search **/ public static int findDelimiter(MemorySegment data, long offset) { return ByteVector.fromMemorySegment(VectorUtils.BYTE_SPECIES, data, offset, ByteOrder.nativeOrder()) .compare(VectorOperators.EQ, ';') .firstTrue(); } /** Vectorized Hashing (explicit vectorization seems slower, overkill?) **/ // private static int[] HASH_ARRAY = initHashArray(); // private static final IntVector HASH_VECTOR = IntVector.fromArray(IntVector.SPECIES_256, HASH_ARRAY, 0); // private static final int HASH_ACCUM = HASH_ARRAY[0] * 31; // // private static int[] initHashArray() { // int[] x = new int[IntVector.SPECIES_256.length()]; // x[x.length - 1] = 1; // for (int i = x.length - 2; i >= 0; i--) // x[i] = x[i + 1] * 31; // // return x; // } /** * Ref: https://github.com/PaulSandoz/vector-api-dev-live-10-2021/blob/main/src/main/java/jmh/BytesHashcode.java * * Essentially we are doing this calculation: * h = h * 31 * 31 * 31 * 31 * 31 * 31 * 31 * 31 + * a[i + 0] * 31 * 31 * 31 * 31 * 31 * 31 * 31 + * a[i + 1] * 31 * 31 * 31 * 31 * 31 * 31 + * a[i + 2] * 31 * 31 * 31 * 31 * 31 + * a[i + 3] * 31 * 31 * 31 * 31 + * a[i + 4] * 31 * 31 * 31 + * a[i + 5] * 31 * 31 + * a[i + 6] * 31 + * a[i + 7]; */ // public static int hashCode(MemorySegment array, long offset, short length) { // int h = 1; // long i = offset, loopBound = offset + ByteVector.SPECIES_64.loopBound(length), tailBound = offset + length; // for (; i < loopBound; i += ByteVector.SPECIES_64.length()) { // // load 8 bytes, into a 64-bit vector // ByteVector b = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, array, i, ByteOrder.nativeOrder()); // // convert 8 bytes into 8 ints (hashing calculation needs int!) // IntVector x = (IntVector) b.castShape(IntVector.SPECIES_256, 0); // h = h * HASH_ACCUM + x.mul(HASH_VECTOR).reduceLanes(VectorOperators.ADD); // } // // for (; i < tailBound; i++) { // h = 31 * h + array.get(ValueLayout.JAVA_BYTE, i); // } // return h; // } // scalar implementation // public static int hashCode(final MemorySegment array, final long offset, final short length) { // final long limit = offset + length; // int h = 1; // for (long i = offset; i < limit; i++) { // h = 31 * h + UNSAFE.getByte(array.address() + i); // } // return h; // } // fxhash public static int hashCode(final MemorySegment array, final long offset, final short length) { final int seed = 0x9E3779B9; final int rotate = 5; int x, y; if (length >= Integer.BYTES) { x = UNSAFE.getInt(array.address() + offset); y = UNSAFE.getInt(array.address() + offset + length - Integer.BYTES); } else { x = UNSAFE.getByte(array.address() + offset); y = UNSAFE.getByte(array.address() + offset + length - Byte.BYTES); } return (Integer.rotateLeft(x * seed, rotate) ^ y) * seed; } /** Vectorized Key Comparison **/ private static boolean notEquals(MemorySegment a, long aOffset, MemorySegment b, long bOffset, short length, VectorSpecies BYTE_SPECIES) { final long aLimit = aOffset + length, bLimit = bOffset + length; // main loop long loopBound = bOffset + BYTE_SPECIES.loopBound(length); for (; bOffset < loopBound; aOffset += BYTE_SPECIES.length(), bOffset += BYTE_SPECIES.length()) { ByteVector av = ByteVector.fromMemorySegment(BYTE_SPECIES, a, aOffset, ByteOrder.nativeOrder() /* , BYTE_SPECIES.indexInRange(aOffset, Math.min(aOffset + BYTE_SPECIES.length(), aLimit)) */); ByteVector bv = ByteVector.fromMemorySegment(BYTE_SPECIES, b, bOffset, ByteOrder.nativeOrder() /* , BYTE_SPECIES.indexInRange(bOffset, Math.min(bOffset + BYTE_SPECIES.length(), bLimit)) */); if (av.compare(VectorOperators.NE, bv).anyTrue()) return true; } // tail cleanup - load last N bytes with mask if (bOffset < bLimit) { ByteVector av = ByteVector.fromMemorySegment(BYTE_SPECIES, a, aOffset, ByteOrder.nativeOrder(), BYTE_SPECIES.indexInRange(aOffset, aLimit)); ByteVector bv = ByteVector.fromMemorySegment(BYTE_SPECIES, b, bOffset, ByteOrder.nativeOrder(), BYTE_SPECIES.indexInRange(bOffset, bLimit)); if (av.compare(VectorOperators.NE, bv).anyTrue()) return true; } return false; } // scalar implementation // private static boolean equals(byte[] a, int aOffset, byte[] b, int bOffset, int len) { // while (bOffset < len) // if (a[aOffset++] != b[bOffset++]) // return false; // return true; // } } /** * Measurement Hash Table (for each partition) * Uses contiguous byte array to optimize for cache-line (hopefully) * * Each entry: * - KEYS: keyLength (2 bytes) + key (100 bytes) * - VALUES: min (2 bytes) + max (2 bytes) + count (4 bytes) + sum ( 8 bytes) */ protected static class PartitionAggr { private static int MAP_SIZE = 1 << 14; // 2^14 = 16384, closes to 10000 private static int KEY_SIZE = 128; // key length (2 bytes) + key (100 bytes) private static int KEY_MASK = (MAP_SIZE - 1); private static int VALUE_SIZE = 16; // min (2 bytes) + max ( 2 bytes) + count (4 bytes) + sum (8 bytes) private MemorySegment KEYS = Arena.ofShared().allocate(MAP_SIZE * KEY_SIZE, 64); private MemorySegment VALUES = Arena.ofShared().allocate(MAP_SIZE * VALUE_SIZE, 16); public PartitionAggr() { // init min and max final long limit = VALUES.address() + (MAP_SIZE * VALUE_SIZE); for (long offset = VALUES.address(); offset < limit; offset += VALUE_SIZE) { UNSAFE.putShort(offset, Short.MAX_VALUE); UNSAFE.putShort(offset + 2, Short.MIN_VALUE); } } public void update(MemorySegment key, long keyStart, short keyLength, int keyHash, short value) { int index = keyHash & KEY_MASK; long keyOffset = KEYS.address() + (index * KEY_SIZE); while (((UNSAFE.getShort(keyOffset) != keyLength) || VectorUtils.notEquals(KEYS, ((index * KEY_SIZE) + 2), key, keyStart, keyLength, VectorUtils.BYTE_SPECIES))) { if (UNSAFE.getShort(keyOffset) == 0) { // put key UNSAFE.putShort(keyOffset, keyLength); MemorySegment.copy(key, keyStart, KEYS, (index * KEY_SIZE) + 2, keyLength); break; } else { index = (index + 1) & KEY_MASK; keyOffset = KEYS.address() + (index * KEY_SIZE); } } long valueOffset = VALUES.address() + (index * VALUE_SIZE); UNSAFE.putShort(valueOffset, (short) Math.min(UNSAFE.getShort(valueOffset), value)); valueOffset += 2; UNSAFE.putShort(valueOffset, (short) Math.max(UNSAFE.getShort(valueOffset), value)); valueOffset += 2; UNSAFE.putInt(valueOffset, UNSAFE.getInt(valueOffset) + 1); valueOffset += 4; UNSAFE.putLong(valueOffset, UNSAFE.getLong(valueOffset) + value); } public void mergeTo(ResultAggr result) { long keyOffset; short keyLength; for (int i = 0; i < MAP_SIZE; i++) { // extract key keyOffset = KEYS.address() + (i * KEY_SIZE); if ((keyLength = UNSAFE.getShort(keyOffset)) == 0) continue; // extract values (if key is not null) final long valueOffset = VALUES.address() + (i * VALUE_SIZE); result.compute(new ResultAggr.ByteKey(KEYS, (i * KEY_SIZE) + 2, keyLength), (k, v) -> { if (v == null) { v = new ResultAggr.Measurement(); } v.min = (short) Math.min(UNSAFE.getShort(valueOffset), v.min); v.max = (short) Math.max(UNSAFE.getShort(valueOffset + 2), v.max); v.count += UNSAFE.getInt(valueOffset + 4); v.sum += UNSAFE.getLong(valueOffset + 8); return v; }); } } } /** * Measurement Aggregation (for all partitions) * Simple Concurrent Hash Table so all partitions can merge concurrently */ protected static class ResultAggr extends HashMap { public static class ByteKey implements Comparable { private final MemorySegment data; private final long offset; private final short length; private String str; public ByteKey(MemorySegment data, long offset, short length) { this.data = data; this.offset = offset; this.length = length; } @Override public boolean equals(Object other) { return (length == ((ByteKey) other).length) && !VectorUtils.notEquals(data, offset, ((ByteKey) other).data, ((ByteKey) other).offset, length, VectorUtils.BYTE_SPECIES); } @Override public int hashCode() { return VectorUtils.hashCode(data, offset, length); } @Override public String toString() { if (str == null) { // finally has to do a copy! byte[] copy = new byte[length]; MemorySegment.copy(data, offset, MemorySegment.ofArray(copy), 0, length); str = new String(copy, StandardCharsets.UTF_8); } return str; } @Override public int compareTo(ByteKey o) { return toString().compareTo(o.toString()); } } protected static class Measurement { public short min = Short.MAX_VALUE; public short max = Short.MIN_VALUE; public int count = 0; public long sum = 0; @Override public String toString() { return ((double) min / 10) + "/" + (Math.round((1.0 * sum) / count) / 10.0) + "/" + ((double) max / 10); } } public ResultAggr(int initialCapacity, float loadFactor) { super(initialCapacity, loadFactor); } public Map toSorted() { return new TreeMap(this); } } protected static class Partition implements Runnable { private final MemorySegment data; private long offset; private final long limit; private final PartitionAggr result; public Partition(MemorySegment data, long offset, long limit, PartitionAggr result) { this.data = data; this.offset = offset; this.limit = limit; this.result = result; } @Override public void run() { // measurement parsing final PartitionAggr aggr = this.result; // main loop (vectorized) final long loopLimit = limit - (VectorUtils.BYTE_SPECIES.length() * Math.ceilDiv(100, VectorUtils.BYTE_SPECIES.length()) + Long.BYTES); while (offset < loopLimit) { long offsetStart = offset; // find station name upto ";" int found; do { found = VectorUtils.findDelimiter(data, offset); offset += found; } while (found == VectorUtils.BYTE_SPECIES.length()); short stationLength = (short) (offset - offsetStart); int stationHash = VectorUtils.hashCode(data, offsetStart, stationLength); // find measurement upto "\n" (credit: merykitty) long numberBits = UNSAFE.getLong(data.address() + ++offset); final long invNumberBits = ~numberBits; final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & 0x10101000); int shift = 28 - decimalSepPos; long signed = (invNumberBits << 59) >> 63; long designMask = ~(signed & 0xFF); long digits = ((numberBits & designMask) << shift) & 0x0F000F0F00L; long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; short fixed = (short) ((absValue ^ signed) - signed); offset += (decimalSepPos >>> 3) + 3; // update measurement aggr.update(data, offsetStart, stationLength, stationHash, fixed); } // tail loop (simple) while (offset < limit) { long offsetStart = offset; // find station name upto ";" short stationLength = 0; while (UNSAFE.getByte(data.address() + offset++) != ';') stationLength++; int stationHash = VectorUtils.hashCode(data, offsetStart, stationLength); // find measurement upto "\n" byte tempBuffer = UNSAFE.getByte(data.address() + offset++); boolean isNegative = (tempBuffer == '-'); short fixed = (short) (isNegative ? 0 : (tempBuffer - '0')); while (true) { tempBuffer = UNSAFE.getByte(data.address() + offset++); if (tempBuffer == '.') { fixed = (short) (fixed * 10 + (UNSAFE.getByte(data.address() + offset) - '0')); offset += 2; break; } fixed = (short) (fixed * 10 + (tempBuffer - '0')); } fixed = isNegative ? (short) -fixed : fixed; // update measurement aggr.update(data, offsetStart, stationLength, stationHash, fixed); } // measurement result collection // aggr.mergeTo(result); } } public static void main(String[] args) throws IOException, InterruptedException { // long startTime = System.currentTimeMillis(); try (FileChannel fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), EnumSet.of(StandardOpenOption.READ)); Arena arena = Arena.ofShared()) { // scan data MemorySegment data = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size(), arena); final int processors = Runtime.getRuntime().availableProcessors(); // partition split long[] partition = new long[processors + 1]; long partitionSize = Math.ceilDiv(data.byteSize(), processors); for (int i = 0; i < processors; i++) { partition[i + 1] = partition[i] + partitionSize; if (partition[i + 1] >= data.byteSize()) { partition[i + 1] = data.byteSize(); break; } // note: vectorize this made performance worse :( while (UNSAFE.getByte(data.address() + partition[i + 1]++) != '\n') ; } // partition aggregation var threadList = new Thread[processors]; PartitionAggr[] partAggrs = new PartitionAggr[processors]; for (int i = 0; i < processors; i++) { if (partition[i] == data.byteSize()) break; partAggrs[i] = new PartitionAggr(); threadList[i] = new Thread(new Partition(data, partition[i], partition[i + 1], partAggrs[i])); threadList[i].start(); } // result ResultAggr result = new ResultAggr(1 << 14, 1); for (int i = 0; i < processors; i++) { if (partition[i] == data.byteSize()) break; threadList[i].join(); partAggrs[i].mergeTo(result); } System.out.println(result.toSorted()); } // long elapsed = System.currentTimeMillis() - startTime; // System.out.println("Elapsed: " + ((double) elapsed / 1000.0)); } /** Unit Tests **/ public static void testMain(String[] args) { testHashCode(); testNotEquals(); } private static void testHashCode() { // test key length from 1 to 100 for (int i = 1; i <= 100; i++) { byte[] array = new byte[i]; for (int j = 0; j < i; j++) array[j] = (byte) j; // compare with java default implementation assertTrue(VectorUtils.hashCode(MemorySegment.ofArray(array), 0, (short) i) == Arrays.hashCode(array)); } } private static void testNotEquals() { byte[] a = new byte[128]; byte[] b = new byte[128]; // all equals for (int i = 1; i < 100; i++) { a[(i + 2) - 1] = 0; b[i - 1] = 0; a[(i + 2)] = 10; b[i] = 10; assertTrue(!VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_64)); assertTrue(!VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_128)); assertTrue(!VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_256)); assertTrue(!VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_512)); } // one el not equals for (int i = 1; i < 100; i++) { a[(i + 2) - 1] = 0; b[i - 1] = 0; a[(i + 2)] = 20; b[i] = 10; assertTrue(VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_64)); assertTrue(VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_128)); assertTrue(VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_256)); assertTrue(VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_512)); } } private static void assertTrue(boolean condition) { if (!condition) { throw new RuntimeException("Failed test"); } } }