Locally another 5% faster, much faster for larger set, made more general (#352)
This commit is contained in:
		| @@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets; | ||||
| import java.nio.file.Path; | ||||
| import java.nio.file.StandardOpenOption; | ||||
| import java.util.HashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import java.util.stream.Collectors; | ||||
| import java.util.stream.IntStream; | ||||
| @@ -53,8 +52,13 @@ import sun.misc.Unsafe; | ||||
|  * Various tweaks for Linux/cache    1550 ms (should/could make a difference on target machine) | ||||
|  * Improved layout/predictability:   1400 ms | ||||
|  * Delayed String creation again:    1350 ms | ||||
|  * Remove writing to buffer:         1335 ms | ||||
|  * Optimized collecting at the end:  1310 ms | ||||
|  * Adding a lot of comments:         priceless | ||||
|  * | ||||
|  * Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai for ideas. | ||||
|  * | ||||
|  * Follow me at: @royvanrijn | ||||
|  */ | ||||
| public class CalculateAverage_royvanrijn { | ||||
|  | ||||
| @@ -74,29 +78,24 @@ public class CalculateAverage_royvanrijn { | ||||
|     } | ||||
|  | ||||
|     public static void main(String[] args) throws Exception { | ||||
|  | ||||
|         // Calculate input segments. | ||||
|         final int numberOfChunks = Runtime.getRuntime().availableProcessors(); | ||||
|         final long[] chunks = getSegments(numberOfChunks); | ||||
|  | ||||
|         final List<Entry[]> repositories = IntStream.range(0, chunks.length - 1) | ||||
|         final Map<String, Entry> measurements = HashMap.newHashMap(1 << 10); | ||||
|         IntStream.range(0, chunks.length - 1) | ||||
|                 .mapToObj(chunkIndex -> processMemoryArea(chunks[chunkIndex], chunks[chunkIndex + 1])) | ||||
|                 .parallel() | ||||
|                 .toList(); | ||||
|  | ||||
|         // Sometimes simple is better: | ||||
|         final HashMap<String, Entry> measurements = HashMap.newHashMap(1 << 10); | ||||
|         for (Entry[] entries : repositories) { | ||||
|             for (Entry entry : entries) { | ||||
|                 if (entry != null) | ||||
|                     measurements.merge(extractedCityFromLongArray(entry.data, entry.length), entry, Entry::mergeWith); | ||||
|             } | ||||
|         } | ||||
|                 .forEachOrdered(repo -> { // make sure it's ordered, no concurrent map | ||||
|                     for (Entry entry : repo) { | ||||
|                         if (entry != null) | ||||
|                             measurements.merge(turnLongArrayIntoString(entry.data, entry.length), entry, Entry::mergeWith); | ||||
|                     } | ||||
|                 }); | ||||
|  | ||||
|         System.out.print("{" + | ||||
|                 measurements.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Object::toString).collect(Collectors.joining(", "))); | ||||
|         System.out.println("}"); | ||||
|  | ||||
|     } | ||||
|  | ||||
|     /** | ||||
| @@ -123,15 +122,20 @@ public class CalculateAverage_royvanrijn { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static final int TABLE_SIZE = 1 << 19; // large enough for the contest. | ||||
|     private static final int TABLE_MASK = (TABLE_SIZE - 1); | ||||
|  | ||||
|     // This is where I store the hashtable entry data in the "hot loop" | ||||
|     // The long[] contains the name in bytes (yeah, confusing) | ||||
|     // I've tried flyweight-ing, carrying all the data in a single byte[], | ||||
|     // where you offset type-indices: min:int,max:int,count:int,etc. | ||||
|     // | ||||
|     // The performance was just a little worse than this simple class. | ||||
|     static final class Entry { | ||||
|         private final long[] data; | ||||
|         private int min, max, count, length; | ||||
|         private long sum; | ||||
|  | ||||
|         Entry(final long[] data, int length, int temp) { | ||||
|         private int min, max, count; | ||||
|         private byte length; | ||||
|         private long sum; | ||||
|         private final long[] data; | ||||
|  | ||||
|         Entry(final long[] data, byte length, int temp) { | ||||
|             this.data = data; | ||||
|             this.length = length; | ||||
|             this.min = temp; | ||||
| @@ -164,127 +168,161 @@ public class CalculateAverage_royvanrijn { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Delay String creation until the end: | ||||
|      * @param data | ||||
|      * @param length | ||||
|      * @return | ||||
|      */ | ||||
|     private static String extractedCityFromLongArray(final long[] data, final int length) { | ||||
|         // Initiate as late as possible: | ||||
|     // Only parse the String at the final end, when we have only the needed entries left that we need to output: | ||||
|     private static String turnLongArrayIntoString(final long[] data, final int length) { | ||||
|         // Create our target byte[] | ||||
|         final byte[] bytes = new byte[length]; | ||||
|         // The power of magic allows us to just copy the memory in there. | ||||
|         UNSAFE.copyMemory(data, Unsafe.ARRAY_LONG_BASE_OFFSET, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); | ||||
|         // And construct a String() | ||||
|         return new String(bytes, StandardCharsets.UTF_8); | ||||
|     } | ||||
|  | ||||
|     private static Entry createNewEntry(final long[] buffer, final int lengthLongs, final int lengthBytes, final int temp) { | ||||
|  | ||||
|     private static Entry createNewEntry(final long fromAddress, final int lengthLongs, final byte lengthBytes, final int temp) { | ||||
|         // Make a copy of our working buffer, store this in a new Entry: | ||||
|         final long[] bufferCopy = new long[lengthLongs]; | ||||
|         System.arraycopy(buffer, 0, bufferCopy, 0, lengthLongs); | ||||
|  | ||||
|         // Add the entry: | ||||
|         // Just copy everything over, bytes into the long[] | ||||
|         UNSAFE.copyMemory(null, fromAddress, bufferCopy, Unsafe.ARRAY_BYTE_BASE_OFFSET, lengthBytes); | ||||
|         return new Entry(bufferCopy, lengthBytes, temp); | ||||
|     } | ||||
|  | ||||
|     private static final int TABLE_SIZE = 1 << 19; | ||||
|     private static final int TABLE_MASK = (TABLE_SIZE - 1); | ||||
|  | ||||
|     private static Entry[] processMemoryArea(final long fromAddress, final long toAddress) { | ||||
|  | ||||
|         final Entry[] table = new Entry[TABLE_SIZE]; | ||||
|         final long[] buffer = new long[16]; | ||||
|  | ||||
|         long ptr = fromAddress; | ||||
|         int bufferPtr; | ||||
|         int packedBytes; | ||||
|         long hash; | ||||
|         long ptr = fromAddress; | ||||
|         long word; | ||||
|         long mask; | ||||
|  | ||||
|         final Entry[] table = new Entry[TABLE_SIZE]; | ||||
|  | ||||
|         // Go from start to finish address through the bytes: | ||||
|         while (ptr < toAddress) { | ||||
|  | ||||
|             final long startAddress = ptr; | ||||
|  | ||||
|             bufferPtr = 0; | ||||
|             hash = 1; | ||||
|             packedBytes = 1; | ||||
|             hash = 0; | ||||
|             word = UNSAFE.getLong(ptr); | ||||
|             mask = getDelimiterMask(word); | ||||
|  | ||||
|             // Removed writing to a buffer here, why would we, we know the address and we'll need to check there anyway. | ||||
|             while (mask == 0) { | ||||
|                 buffer[bufferPtr++] = word; | ||||
|                 // If the mask is zero, we have no ';' | ||||
|                 packedBytes++; | ||||
|                 // So we continue building the hash: | ||||
|                 hash ^= word; | ||||
|                 ptr += 8; | ||||
|  | ||||
|                 // And getting a new value and mask: | ||||
|                 word = UNSAFE.getLong(ptr); | ||||
|                 mask = getDelimiterMask(word); | ||||
|             } | ||||
|  | ||||
|             // Found delimiter: | ||||
|             final long delimiterAddress = ptr + (Long.numberOfTrailingZeros(mask) >> 3); | ||||
|             final long numberBits = UNSAFE.getLong(delimiterAddress + 1); | ||||
|             final int delimiterByte = Long.numberOfTrailingZeros(mask); | ||||
|             final long delimiterAddress = ptr + (delimiterByte >> 3); | ||||
|  | ||||
|             // Finish the masks and hash: | ||||
|             word = word & ((mask >> 7) - 1); | ||||
|             buffer[bufferPtr++] = word; | ||||
|             hash ^= word; | ||||
|             final long partialWord = word & ((mask >>> 7) - 1); | ||||
|             hash ^= partialWord; | ||||
|  | ||||
|             final long invNumberBits = ~numberBits; | ||||
|             final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & DOT_BITS); | ||||
|             // Read a long value from memory starting from the delimiter + 1, the number part: | ||||
|             final long numberBytes = UNSAFE.getLong(delimiterAddress + 1); | ||||
|             final long invNumberBytes = ~numberBytes; | ||||
|  | ||||
|             // Update counter asap, lets CPU predict. | ||||
|             // Adjust our pointer | ||||
|             final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBytes & DOT_BITS); | ||||
|             ptr = delimiterAddress + (decimalSepPos >> 3) + 4; | ||||
|  | ||||
|             // Awesome idea of merykitty: | ||||
|             final int temp = extractTemp(numberBits, invNumberBits, decimalSepPos); | ||||
|  | ||||
|             int intHash = (int) (hash ^ (hash >>> 33)); // offset for extra entropy | ||||
|             // Calculate the final hash and index of the table: | ||||
|             int intHash = (int) (hash ^ (hash >> 32)); | ||||
|             intHash = intHash ^ (intHash >> 17); | ||||
|             int index = intHash & TABLE_MASK; | ||||
|  | ||||
|             // Find or insert the entry: | ||||
|             while (true) { | ||||
|                 Entry tableEntry = table[index]; | ||||
|                 if (tableEntry == null) { | ||||
|                     final int length = (int) (delimiterAddress - startAddress); | ||||
|                     table[index] = createNewEntry(buffer, bufferPtr, length, temp); | ||||
|                     final int temp = extractTemp(decimalSepPos, invNumberBytes, numberBytes); | ||||
|                     // Create a new entry: | ||||
|                     final byte length = (byte) (delimiterAddress - startAddress); | ||||
|                     table[index] = createNewEntry(startAddress, packedBytes, length, temp); | ||||
|                     break; | ||||
|                 } | ||||
|                 else if (bufferPtr == tableEntry.data.length) { | ||||
|                     if (!arrayEquals(buffer, tableEntry.data, bufferPtr)) { | ||||
|                         index = (index + 1) & TABLE_MASK; | ||||
|                         continue; | ||||
|                     } | ||||
|                     // No differences in array | ||||
|                 // Don't bother re-checking things here like hash or length. | ||||
|                 // we'll need to check the content anyway if it's a hit, which is most times | ||||
|                 else if (memoryEqualsEntry(startAddress, tableEntry.data, partialWord, packedBytes)) { | ||||
|                     // temperature, you're not temporary my friend | ||||
|                     final int temp = extractTemp(decimalSepPos, invNumberBytes, numberBytes); | ||||
|                     // No differences, same entry: | ||||
|                     tableEntry.updateWith(temp); | ||||
|                     break; | ||||
|                 } | ||||
|                 // Move to the next index | ||||
|                 // Move to the next in the table, linear probing: | ||||
|                 index = (index + 1) & TABLE_MASK; | ||||
|             } | ||||
|         } | ||||
|         return table; | ||||
|     } | ||||
|  | ||||
|     private static int extractTemp(final long numberBits, final long invNumberBits, final int decimalSepPos) { | ||||
|     /* | ||||
|      * `___` ___ ___ _ ___` ` ___ ` _ ` _ ` _` ___ | ||||
|      * / ` \| _ \ __| \| \ \ / /_\ | | | | | | __| | ||||
|      * | () | _ / __|| . |\ V / _ \| |_| |_| | ._| | ||||
|      * \___/|_| |___|_|\_| \_/_/ \_\___|\___/|___| | ||||
|      * ---------------- BETTER SOFTWARE, FASTER -- | ||||
|      * | ||||
|      * https://www.openvalue.eu/ | ||||
|      * | ||||
|      * Made you look. | ||||
|      * | ||||
|      */ | ||||
|  | ||||
|     private static final long DOT_BITS = 0x10101000; | ||||
|     private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); | ||||
|  | ||||
|     private static int extractTemp(final int decimalSepPos, final long invNumberBits, final long numberBits) { | ||||
|         // Awesome idea of merykitty: | ||||
|         int min28 = (28 - decimalSepPos); | ||||
|         // Calculates the sign | ||||
|         final long signed = (invNumberBits << 59) >> 63; | ||||
|         final long minusFilter = ~(signed & 0xFF); | ||||
|         final long digits = ((numberBits & minusFilter) << (28 - decimalSepPos)) & 0x0F000F0F00L; | ||||
|         final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; // filter just the result | ||||
|         // Use the pre-calculated decimal position to adjust the values | ||||
|         final long digits = ((numberBits & minusFilter) << min28) & 0x0F000F0F00L; | ||||
|         // Multiply by a magic (100 * 0x1000000 + 10 * 0x10000 + 1), to get the result | ||||
|         final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; | ||||
|         // And perform abs() | ||||
|         final int temp = (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick | ||||
|         return temp; | ||||
|     } | ||||
|  | ||||
|     private static long getDelimiterMask(final long word) { | ||||
|         long match = word ^ SEPARATOR_PATTERN; | ||||
|         return (match - 0x0101010101010101L) & ~match & 0x8080808080808080L; | ||||
|     } | ||||
|  | ||||
|     private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL; | ||||
|     private static final long DOT_BITS = 0x10101000; | ||||
|     private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); | ||||
|  | ||||
|     // Takes a long and finds the bytes where this exact pattern is present. | ||||
|     // Cool bit manipulation technique: SWAR (SIMD as a Register). | ||||
|     private static long getDelimiterMask(final long word) { | ||||
|         final long match = word ^ SEPARATOR_PATTERN; | ||||
|         return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L); | ||||
|         // I've put some brackets separating the first and second part, this is faster. | ||||
|         // Now they run simultaneous after 'match' is altered, instead of waiting on each other. | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * For case multiple hashes are equal (however unlikely) check the actual key (using longs) | ||||
|      */ | ||||
|     static boolean arrayEquals(final long[] a, final long[] b, final int length) { | ||||
|         for (int i = 0; i < length; i++) { | ||||
|             if (a[i] != b[i]) | ||||
|     private static boolean memoryEqualsEntry(final long startAddress, final long[] entry, final long finalBytes, final int amountLong) { | ||||
|         for (int i = 0; i < (amountLong - 1); i++) { | ||||
|             int step = i << 3; // step by 8 bytes | ||||
|             if (UNSAFE.getLong(startAddress + step) != entry[i]) | ||||
|                 return false; | ||||
|         } | ||||
|         return true; | ||||
|         // If all previous 'whole' 8-packed byte-long values are equal | ||||
|         // We still need to check the final bytes that don't fit. | ||||
|         // and we've already calculated them for the hash. | ||||
|         return finalBytes == entry[amountLong - 1]; | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user