From 9227aa50624fcb86b5afaab9769d2dc075a605bb Mon Sep 17 00:00:00 2001 From: Roy van Rijn Date: Fri, 12 Jan 2024 21:00:12 +0100 Subject: [PATCH] Locally another 5% faster, much faster for larger set, made more general (#352) --- .../onebrc/CalculateAverage_royvanrijn.java | 192 +++++++++++------- 1 file changed, 115 insertions(+), 77 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java index f1e8303..307833f 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java @@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -53,8 +52,13 @@ import sun.misc.Unsafe; * Various tweaks for Linux/cache 1550 ms (should/could make a difference on target machine) * Improved layout/predictability: 1400 ms * Delayed String creation again: 1350 ms + * Remove writing to buffer: 1335 ms + * Optimized collecting at the end: 1310 ms + * Adding a lot of comments: priceless * * Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai for ideas. + * + * Follow me at: @royvanrijn */ public class CalculateAverage_royvanrijn { @@ -74,29 +78,24 @@ public class CalculateAverage_royvanrijn { } public static void main(String[] args) throws Exception { - // Calculate input segments. final int numberOfChunks = Runtime.getRuntime().availableProcessors(); final long[] chunks = getSegments(numberOfChunks); - final List repositories = IntStream.range(0, chunks.length - 1) + final Map measurements = HashMap.newHashMap(1 << 10); + IntStream.range(0, chunks.length - 1) .mapToObj(chunkIndex -> processMemoryArea(chunks[chunkIndex], chunks[chunkIndex + 1])) .parallel() - .toList(); - - // Sometimes simple is better: - final HashMap measurements = HashMap.newHashMap(1 << 10); - for (Entry[] entries : repositories) { - for (Entry entry : entries) { - if (entry != null) - measurements.merge(extractedCityFromLongArray(entry.data, entry.length), entry, Entry::mergeWith); - } - } + .forEachOrdered(repo -> { // make sure it's ordered, no concurrent map + for (Entry entry : repo) { + if (entry != null) + measurements.merge(turnLongArrayIntoString(entry.data, entry.length), entry, Entry::mergeWith); + } + }); System.out.print("{" + measurements.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Object::toString).collect(Collectors.joining(", "))); System.out.println("}"); - } /** @@ -123,15 +122,20 @@ public class CalculateAverage_royvanrijn { } } - private static final int TABLE_SIZE = 1 << 19; // large enough for the contest. - private static final int TABLE_MASK = (TABLE_SIZE - 1); - + // This is where I store the hashtable entry data in the "hot loop" + // The long[] contains the name in bytes (yeah, confusing) + // I've tried flyweight-ing, carrying all the data in a single byte[], + // where you offset type-indices: min:int,max:int,count:int,etc. + // + // The performance was just a little worse than this simple class. static final class Entry { - private final long[] data; - private int min, max, count, length; - private long sum; - Entry(final long[] data, int length, int temp) { + private int min, max, count; + private byte length; + private long sum; + private final long[] data; + + Entry(final long[] data, byte length, int temp) { this.data = data; this.length = length; this.min = temp; @@ -164,127 +168,161 @@ public class CalculateAverage_royvanrijn { } } - /** - * Delay String creation until the end: - * @param data - * @param length - * @return - */ - private static String extractedCityFromLongArray(final long[] data, final int length) { - // Initiate as late as possible: + // Only parse the String at the final end, when we have only the needed entries left that we need to output: + private static String turnLongArrayIntoString(final long[] data, final int length) { + // Create our target byte[] final byte[] bytes = new byte[length]; + // The power of magic allows us to just copy the memory in there. UNSAFE.copyMemory(data, Unsafe.ARRAY_LONG_BASE_OFFSET, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); + // And construct a String() return new String(bytes, StandardCharsets.UTF_8); } - private static Entry createNewEntry(final long[] buffer, final int lengthLongs, final int lengthBytes, final int temp) { - + private static Entry createNewEntry(final long fromAddress, final int lengthLongs, final byte lengthBytes, final int temp) { + // Make a copy of our working buffer, store this in a new Entry: final long[] bufferCopy = new long[lengthLongs]; - System.arraycopy(buffer, 0, bufferCopy, 0, lengthLongs); - - // Add the entry: + // Just copy everything over, bytes into the long[] + UNSAFE.copyMemory(null, fromAddress, bufferCopy, Unsafe.ARRAY_BYTE_BASE_OFFSET, lengthBytes); return new Entry(bufferCopy, lengthBytes, temp); } + private static final int TABLE_SIZE = 1 << 19; + private static final int TABLE_MASK = (TABLE_SIZE - 1); + private static Entry[] processMemoryArea(final long fromAddress, final long toAddress) { - final Entry[] table = new Entry[TABLE_SIZE]; - final long[] buffer = new long[16]; - - long ptr = fromAddress; - int bufferPtr; + int packedBytes; long hash; + long ptr = fromAddress; long word; long mask; + final Entry[] table = new Entry[TABLE_SIZE]; + + // Go from start to finish address through the bytes: while (ptr < toAddress) { final long startAddress = ptr; - bufferPtr = 0; - hash = 1; + packedBytes = 1; + hash = 0; word = UNSAFE.getLong(ptr); mask = getDelimiterMask(word); + // Removed writing to a buffer here, why would we, we know the address and we'll need to check there anyway. while (mask == 0) { - buffer[bufferPtr++] = word; + // If the mask is zero, we have no ';' + packedBytes++; + // So we continue building the hash: hash ^= word; ptr += 8; + // And getting a new value and mask: word = UNSAFE.getLong(ptr); mask = getDelimiterMask(word); } + // Found delimiter: - final long delimiterAddress = ptr + (Long.numberOfTrailingZeros(mask) >> 3); - final long numberBits = UNSAFE.getLong(delimiterAddress + 1); + final int delimiterByte = Long.numberOfTrailingZeros(mask); + final long delimiterAddress = ptr + (delimiterByte >> 3); // Finish the masks and hash: - word = word & ((mask >> 7) - 1); - buffer[bufferPtr++] = word; - hash ^= word; + final long partialWord = word & ((mask >>> 7) - 1); + hash ^= partialWord; - final long invNumberBits = ~numberBits; - final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & DOT_BITS); + // Read a long value from memory starting from the delimiter + 1, the number part: + final long numberBytes = UNSAFE.getLong(delimiterAddress + 1); + final long invNumberBytes = ~numberBytes; - // Update counter asap, lets CPU predict. + // Adjust our pointer + final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBytes & DOT_BITS); ptr = delimiterAddress + (decimalSepPos >> 3) + 4; - // Awesome idea of merykitty: - final int temp = extractTemp(numberBits, invNumberBits, decimalSepPos); - - int intHash = (int) (hash ^ (hash >>> 33)); // offset for extra entropy + // Calculate the final hash and index of the table: + int intHash = (int) (hash ^ (hash >> 32)); + intHash = intHash ^ (intHash >> 17); int index = intHash & TABLE_MASK; // Find or insert the entry: while (true) { Entry tableEntry = table[index]; if (tableEntry == null) { - final int length = (int) (delimiterAddress - startAddress); - table[index] = createNewEntry(buffer, bufferPtr, length, temp); + final int temp = extractTemp(decimalSepPos, invNumberBytes, numberBytes); + // Create a new entry: + final byte length = (byte) (delimiterAddress - startAddress); + table[index] = createNewEntry(startAddress, packedBytes, length, temp); break; } - else if (bufferPtr == tableEntry.data.length) { - if (!arrayEquals(buffer, tableEntry.data, bufferPtr)) { - index = (index + 1) & TABLE_MASK; - continue; - } - // No differences in array + // Don't bother re-checking things here like hash or length. + // we'll need to check the content anyway if it's a hit, which is most times + else if (memoryEqualsEntry(startAddress, tableEntry.data, partialWord, packedBytes)) { + // temperature, you're not temporary my friend + final int temp = extractTemp(decimalSepPos, invNumberBytes, numberBytes); + // No differences, same entry: tableEntry.updateWith(temp); break; } - // Move to the next index + // Move to the next in the table, linear probing: index = (index + 1) & TABLE_MASK; } } return table; } - private static int extractTemp(final long numberBits, final long invNumberBits, final int decimalSepPos) { + /* + * `___` ___ ___ _ ___` ` ___ ` _ ` _ ` _` ___ + * / ` \| _ \ __| \| \ \ / /_\ | | | | | | __| + * | () | _ / __|| . |\ V / _ \| |_| |_| | ._| + * \___/|_| |___|_|\_| \_/_/ \_\___|\___/|___| + * ---------------- BETTER SOFTWARE, FASTER -- + * + * https://www.openvalue.eu/ + * + * Made you look. + * + */ + + private static final long DOT_BITS = 0x10101000; + private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); + + private static int extractTemp(final int decimalSepPos, final long invNumberBits, final long numberBits) { + // Awesome idea of merykitty: + int min28 = (28 - decimalSepPos); + // Calculates the sign final long signed = (invNumberBits << 59) >> 63; final long minusFilter = ~(signed & 0xFF); - final long digits = ((numberBits & minusFilter) << (28 - decimalSepPos)) & 0x0F000F0F00L; - final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; // filter just the result + // Use the pre-calculated decimal position to adjust the values + final long digits = ((numberBits & minusFilter) << min28) & 0x0F000F0F00L; + // Multiply by a magic (100 * 0x1000000 + 10 * 0x10000 + 1), to get the result + final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; + // And perform abs() final int temp = (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick return temp; } - private static long getDelimiterMask(final long word) { - long match = word ^ SEPARATOR_PATTERN; - return (match - 0x0101010101010101L) & ~match & 0x8080808080808080L; - } - private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL; - private static final long DOT_BITS = 0x10101000; - private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); + + // Takes a long and finds the bytes where this exact pattern is present. + // Cool bit manipulation technique: SWAR (SIMD as a Register). + private static long getDelimiterMask(final long word) { + final long match = word ^ SEPARATOR_PATTERN; + return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L); + // I've put some brackets separating the first and second part, this is faster. + // Now they run simultaneous after 'match' is altered, instead of waiting on each other. + } /** * For case multiple hashes are equal (however unlikely) check the actual key (using longs) */ - static boolean arrayEquals(final long[] a, final long[] b, final int length) { - for (int i = 0; i < length; i++) { - if (a[i] != b[i]) + private static boolean memoryEqualsEntry(final long startAddress, final long[] entry, final long finalBytes, final int amountLong) { + for (int i = 0; i < (amountLong - 1); i++) { + int step = i << 3; // step by 8 bytes + if (UNSAFE.getLong(startAddress + step) != entry[i]) return false; } - return true; + // If all previous 'whole' 8-packed byte-long values are equal + // We still need to check the final bytes that don't fit. + // and we've already calculated them for the hash. + return finalBytes == entry[amountLong - 1]; } }