10k improvement (#419)
* Remove commented-out params from the script
* General cleanup and refactoring
* Deoptimize parseTemperatureSimple
* Optimize nameEquals()
This commit is contained in:
		| @@ -155,39 +155,52 @@ public class CalculateAverage_mtopolnik { | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         private static final int MAX_TEMPERATURE_LEN = 5; | ||||
|         private static final int MAX_ROW_LEN = MAX_NAME_LEN + 1 + MAX_TEMPERATURE_LEN + 1; | ||||
|         private static final long DANGER_ZONE_LENGTH = ((MAX_ROW_LEN - 1) / 8 * 8 + 8); | ||||
|  | ||||
|         private void processChunk() { | ||||
|             while (cursor < inputSize) { | ||||
|                 boolean withinSafeZone; | ||||
|                 long word1; | ||||
|                 long word2; | ||||
|                 if (cursor + 2 * Long.BYTES <= inputSize) { | ||||
|                     word1 = UNSAFE.getLong(inputBase + cursor); | ||||
|                     word2 = UNSAFE.getLong(inputBase + cursor + Long.BYTES); | ||||
|                 long nameLen; | ||||
|                 long nameStartAddress = inputBase + cursor; | ||||
|                 if (cursor + DANGER_ZONE_LENGTH <= inputSize) { | ||||
|                     withinSafeZone = true; | ||||
|                     word1 = UNSAFE.getLong(nameStartAddress); | ||||
|                     word2 = UNSAFE.getLong(nameStartAddress + Long.BYTES); | ||||
|                     nameLen = nameLen(word1, word2, withinSafeZone); | ||||
|                     word1 = maskWord(word1, nameLen); | ||||
|                     word2 = maskWord(word2, nameLen - Long.BYTES); | ||||
|                 } | ||||
|                 else { | ||||
|                     withinSafeZone = false; | ||||
|                     UNSAFE.putLong(nameBufBase, 0); | ||||
|                     UNSAFE.putLong(nameBufBase + Long.BYTES, 0); | ||||
|                     UNSAFE.copyMemory(inputBase + cursor, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor)); | ||||
|                     UNSAFE.copyMemory(nameStartAddress, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor)); | ||||
|                     word1 = UNSAFE.getLong(nameBufBase); | ||||
|                     word2 = UNSAFE.getLong(nameBufBase + Long.BYTES); | ||||
|                     nameLen = nameLen(word1, word2, withinSafeZone); | ||||
|                 } | ||||
|                 long posOfSemicolon = posOfSemicolon(word1, word2); | ||||
|                 word1 = maskWord(word1, posOfSemicolon - cursor); | ||||
|                 word2 = maskWord(word2, posOfSemicolon - cursor - Long.BYTES); | ||||
|                 long hash = hash(word1); | ||||
|                 long namePos = cursor; | ||||
|                 long nameLen = posOfSemicolon - cursor; | ||||
|                 assert nameLen <= 100 : "nameLen > 100"; | ||||
|                 int temperature = parseTemperatureAndAdvanceCursor(posOfSemicolon); | ||||
|                 updateStats(hash, namePos, nameLen, word1, word2, temperature); | ||||
|                 assert nameLen > 0 && nameLen <= 100 : nameLen; | ||||
|                 long tempStartAddress = nameStartAddress + nameLen + 1; | ||||
|                 int temperature = withinSafeZone | ||||
|                         ? parseTemperatureSwarAndAdvanceCursor(tempStartAddress) | ||||
|                         : parseTemperatureSimpleAndAdvanceCursor(tempStartAddress); | ||||
|                 updateStats(hash, nameStartAddress, nameLen, word1, word2, temperature, withinSafeZone); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         private void updateStats(long hash, long namePos, long nameLen, long nameWord1, long nameWord2, int temperature) { | ||||
|         private void updateStats( | ||||
|                                  long hash, long nameStartAddress, long nameLen, long nameWord1, long nameWord2, | ||||
|                                  int temperature, boolean withinSafeZone) { | ||||
|             int tableIndex = (int) (hash & TABLE_INDEX_MASK); | ||||
|             while (true) { | ||||
|                 stats.gotoIndex(tableIndex); | ||||
|                 if (stats.hash() == hash && stats.nameLen() == nameLen | ||||
|                         && nameEquals(stats.nameAddress(), inputBase + namePos, nameLen, nameWord1, nameWord2)) { | ||||
|                 if (stats.hash() == hash && stats.nameLen() == nameLen && nameEquals( | ||||
|                         stats.nameAddress(), nameStartAddress, nameLen, nameWord1, nameWord2, withinSafeZone)) { | ||||
|                     stats.setSum(stats.sum() + temperature); | ||||
|                     stats.setCount(stats.count() + 1); | ||||
|                     stats.setMin((short) Integer.min(stats.min(), temperature)); | ||||
| @@ -204,72 +217,58 @@ public class CalculateAverage_mtopolnik { | ||||
|                 stats.setCount(1); | ||||
|                 stats.setMin((short) temperature); | ||||
|                 stats.setMax((short) temperature); | ||||
|                 UNSAFE.copyMemory(inputBase + namePos, stats.nameAddress(), nameLen); | ||||
|                 UNSAFE.copyMemory(nameStartAddress, stats.nameAddress(), nameLen); | ||||
|                 return; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         private int parseTemperatureAndAdvanceCursor(long semicolonPos) { | ||||
|             long startOffset = semicolonPos + 1; | ||||
|             if (startOffset <= inputSize - Long.BYTES) { | ||||
|                 return parseTemperatureSwarAndAdvanceCursor(startOffset); | ||||
|             } | ||||
|             return parseTemperatureSimpleAndAdvanceCursor(startOffset); | ||||
|         } | ||||
|  | ||||
|         // Credit: merykitty | ||||
|         private int parseTemperatureSwarAndAdvanceCursor(long startOffset) { | ||||
|             long word = UNSAFE.getLong(inputBase + startOffset); | ||||
|         private int parseTemperatureSwarAndAdvanceCursor(long tempStartAddress) { | ||||
|             long word = UNSAFE.getLong(tempStartAddress); | ||||
|             final long negated = ~word; | ||||
|             final int dotPos = Long.numberOfTrailingZeros(negated & 0x10101000); | ||||
|             cursor = (tempStartAddress + (dotPos / 8) + 3) - inputBase; | ||||
|             final long signed = (negated << 59) >> 63; | ||||
|             final long removeSignMask = ~(signed & 0xFF); | ||||
|             final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L; | ||||
|             final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; | ||||
|             final int temperature = (int) ((absValue ^ signed) - signed); | ||||
|             cursor = startOffset + (dotPos / 8) + 3; | ||||
|             return temperature; | ||||
|             return (int) ((absValue ^ signed) - signed); | ||||
|         } | ||||
|  | ||||
|         private int parseTemperatureSimpleAndAdvanceCursor(long startOffset) { | ||||
|         private int parseTemperatureSimpleAndAdvanceCursor(long tempStartAddress) { | ||||
|             final byte minus = (byte) '-'; | ||||
|             final byte zero = (byte) '0'; | ||||
|             final byte dot = (byte) '.'; | ||||
|  | ||||
|             // Temperature plus the following newline is at least 4 chars, so this is always safe: | ||||
|             int fourCh = UNSAFE.getInt(inputBase + startOffset); | ||||
|             final int mask = 0xFF; | ||||
|             byte ch = (byte) (fourCh & mask); | ||||
|             int shift = 0; | ||||
|             byte ch = UNSAFE.getByte(tempStartAddress); | ||||
|             long address = tempStartAddress; | ||||
|             int temperature; | ||||
|             int sign; | ||||
|             if (ch == minus) { | ||||
|                 sign = -1; | ||||
|                 shift += 8; | ||||
|                 ch = (byte) ((fourCh & (mask << shift)) >>> shift); | ||||
|                 address++; | ||||
|                 ch = UNSAFE.getByte(address); | ||||
|             } | ||||
|             else { | ||||
|                 sign = 1; | ||||
|             } | ||||
|             temperature = ch - zero; | ||||
|             shift += 8; | ||||
|             ch = (byte) ((fourCh & (mask << shift)) >>> shift); | ||||
|             address++; | ||||
|             ch = UNSAFE.getByte(address); | ||||
|             if (ch == dot) { | ||||
|                 shift += 8; | ||||
|                 ch = (byte) ((fourCh & (mask << shift)) >>> shift); | ||||
|                 address++; | ||||
|                 ch = UNSAFE.getByte(address); | ||||
|             } | ||||
|             else { | ||||
|                 temperature = 10 * temperature + (ch - zero); | ||||
|                 shift += 16; | ||||
|                 // The last character may be past the four loaded bytes, load it from memory. | ||||
|                 // Checking that with another `if` is self-defeating for performance. | ||||
|                 ch = UNSAFE.getByte(inputBase + startOffset + (shift / 8)); | ||||
|                 address += 2; | ||||
|                 ch = UNSAFE.getByte(address); | ||||
|             } | ||||
|             temperature = 10 * temperature + (ch - zero); | ||||
|             // `shift` holds the number of bits in the temperature field. | ||||
|             // `address` now points to the last character (the fractional digit) of the temperature field. | ||||
|             // A newline character follows the temperature, and so we advance | ||||
|             // the cursor past the newline to the start of the next line. | ||||
|             cursor = startOffset + (shift / 8) + 2; | ||||
|             cursor = (address + 2) - inputBase; | ||||
|             return sign * temperature; | ||||
|         } | ||||
|  | ||||
| @@ -286,15 +285,27 @@ public class CalculateAverage_mtopolnik { | ||||
|             return hash; | ||||
|         } | ||||
|  | ||||
|         private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2) { | ||||
|         private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2, | ||||
|                                           boolean withinSafeZone) { | ||||
|             boolean mismatch1 = maskWord(inputWord1, len) != UNSAFE.getLong(statsAddr); | ||||
|             boolean mismatch2 = maskWord(inputWord2, len - Long.BYTES) != UNSAFE.getLong(statsAddr + Long.BYTES); | ||||
|             if (mismatch1 | mismatch2) { | ||||
|                 return false; | ||||
|             if (len <= 2 * Long.BYTES) { | ||||
|                 return !(mismatch1 | mismatch2); | ||||
|             } | ||||
|             for (int i = 2 * Long.BYTES; i < len; i++) { | ||||
|                 if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) { | ||||
|                     return false; | ||||
|             if (withinSafeZone) { | ||||
|                 int i = 2 * Long.BYTES; | ||||
|                 for (; i <= len - Long.BYTES; i += Long.BYTES) { | ||||
|                     if (UNSAFE.getLong(inputAddr + i) != UNSAFE.getLong(statsAddr + i)) { | ||||
|                         return false; | ||||
|                     } | ||||
|                 } | ||||
|                 return maskWord(UNSAFE.getLong(inputAddr + i), len - i) == UNSAFE.getLong(statsAddr + i); | ||||
|             } | ||||
|             else { | ||||
|                 for (int i = 2 * Long.BYTES; i < len; i++) { | ||||
|                     if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) { | ||||
|                         return false; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             return true; | ||||
| @@ -311,44 +322,62 @@ public class CalculateAverage_mtopolnik { | ||||
|  | ||||
|         // Adapted from https://jameshfisher.com/2017/01/24/bitwise-check-for-zero-byte/ | ||||
|         // and https://github.com/ashvardanian/StringZilla/blob/14e7a78edcc16b031c06b375aac1f66d8f19d45a/stringzilla/stringzilla.h#L139-L169 | ||||
|         long posOfSemicolon(long word1, long word2) { | ||||
|             long diff = word1 ^ BROADCAST_SEMICOLON; | ||||
|             long matchBits1 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; | ||||
|             diff = word2 ^ BROADCAST_SEMICOLON; | ||||
|             long matchBits2 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; | ||||
|             if ((matchBits1 | matchBits2) != 0) { | ||||
|                 int trailing1 = Long.numberOfTrailingZeros(matchBits1); | ||||
|                 int match1IsNonZero = trailing1 & 63; | ||||
|                 match1IsNonZero |= match1IsNonZero >>> 3; | ||||
|                 match1IsNonZero |= match1IsNonZero >>> 1; | ||||
|                 match1IsNonZero |= match1IsNonZero >>> 1; | ||||
|                 // Now match1IsNonZero is 1 if it's non-zero, else 0. Use it to | ||||
|                 // raise the lowest bit in traling2 if trailing1 is nonzero. This forces | ||||
|                 // trailing2 to be zero if trailing1 is non-zero. | ||||
|                 int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63; | ||||
|                 return cursor + ((trailing1 | trailing2) >> 3); | ||||
|             } | ||||
|             long offset = cursor + 2 * Long.BYTES; | ||||
|             for (; offset <= inputSize - Long.BYTES; offset += Long.BYTES) { | ||||
|                 var block = UNSAFE.getLong(inputBase + offset); | ||||
|                 diff = block ^ BROADCAST_SEMICOLON; | ||||
|                 long matchBits = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; | ||||
|                 if (matchBits != 0) { | ||||
|                     return offset + Long.numberOfTrailingZeros(matchBits) / 8; | ||||
|         long nameLen(long word1, long word2, boolean withinSafeZone) { | ||||
|             { | ||||
|                 long matchBits1 = matchBits(word1); | ||||
|                 long matchBits2 = matchBits(word2); | ||||
|                 if ((matchBits1 | matchBits2) != 0) { | ||||
|                     int trailing1 = Long.numberOfTrailingZeros(matchBits1); | ||||
|                     int match1IsNonZero = trailing1 & 63; | ||||
|                     match1IsNonZero |= match1IsNonZero >>> 3; | ||||
|                     match1IsNonZero |= match1IsNonZero >>> 1; | ||||
|                     match1IsNonZero |= match1IsNonZero >>> 1; | ||||
|                     // Now the lowest bit of match1IsNonZero is set if matchBits1 is non-zero, else it is 0. | ||||
|                     // OR-ing it into matchBits2 raises the lowest bit there, which forces | ||||
|                     // trailing2 to be zero whenever a semicolon was already found in word1. | ||||
|                     int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63; | ||||
|                     // trailing1 | trailing2 works like trailing1 + trailing2 because if trailing2 is non-zero, | ||||
|                     // then trailing1 is 64, and since trailing2 is < 64, there's no bit overlap. | ||||
|                     return (trailing1 | trailing2) >> 3; | ||||
|                 } | ||||
|             } | ||||
|             return posOfSemicolonSimple(offset); | ||||
|             long nameStartAddress = inputBase + cursor; | ||||
|             long address = nameStartAddress + 2 * Long.BYTES; | ||||
|             long limit = inputBase + inputSize; | ||||
|             if (withinSafeZone) { | ||||
|                 for (; address < limit; address += Long.BYTES) { | ||||
|                     var block = maskWord(UNSAFE.getLong(address), limit - address); | ||||
|                     long matchBits = matchBits(block); | ||||
|                     if (matchBits != 0) { | ||||
|                         return address + (Long.numberOfTrailingZeros(matchBits) >> 3) - nameStartAddress; | ||||
|                     } | ||||
|                 } | ||||
|                 throw new RuntimeException("Semicolon not found"); | ||||
|             } | ||||
|             return addrOfSemicolonSafe(address, limit) - nameStartAddress; | ||||
|         } | ||||
|  | ||||
|         private long posOfSemicolonSimple(long offset) { | ||||
|             for (; offset < inputSize; offset++) { | ||||
|                 if (UNSAFE.getByte(inputBase + offset) == SEMICOLON) { | ||||
|                     return offset; | ||||
|         private static long addrOfSemicolonSafe(long address, long limit) { | ||||
|             for (; address < limit - Long.BYTES + 1; address += Long.BYTES) { | ||||
|                 var block = UNSAFE.getLong(address); | ||||
|                 long matchBits = matchBits(block); | ||||
|                 if (matchBits != 0) { | ||||
|                     return address + (Long.numberOfTrailingZeros(matchBits) >> 3); | ||||
|                 } | ||||
|             } | ||||
|             for (; address < limit; address++) { | ||||
|                 if (UNSAFE.getByte(address) == SEMICOLON) { | ||||
|                     return address; | ||||
|                 } | ||||
|             } | ||||
|             throw new RuntimeException("Semicolon not found"); | ||||
|         } | ||||
|  | ||||
|         private static long matchBits(long word) { | ||||
|             long diff = word ^ BROADCAST_SEMICOLON; | ||||
|             return (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; | ||||
|         } | ||||
|  | ||||
|         // Copies the results from native memory to Java heap and puts them into the results array. | ||||
|         private void exportResults() { | ||||
|             var exportedStats = new ArrayList<StationStats>(10_000); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user