diff --git a/calculate_average_mtopolnik.sh b/calculate_average_mtopolnik.sh index e48711a..24b5a1c 100755 --- a/calculate_average_mtopolnik.sh +++ b/calculate_average_mtopolnik.sh @@ -15,7 +15,5 @@ # limitations under the License. # -# -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_mtopolnik::recordMeasurementAndAdvanceCursor" -# -XX:InlineSmallCode=10000 -XX:-TieredCompilation -XX:CICompilerCount=2 -XX:CompileThreshold=1000\ java --enable-preview \ --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_mtopolnik diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java index fe487fc..51ea415 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java @@ -155,39 +155,52 @@ public class CalculateAverage_mtopolnik { } } + private static final int MAX_TEMPERATURE_LEN = 5; + private static final int MAX_ROW_LEN = MAX_NAME_LEN + 1 + MAX_TEMPERATURE_LEN + 1; + private static final long DANGER_ZONE_LENGTH = ((MAX_ROW_LEN - 1) / 8 * 8 + 8); + private void processChunk() { while (cursor < inputSize) { + boolean withinSafeZone; long word1; long word2; - if (cursor + 2 * Long.BYTES <= inputSize) { - word1 = UNSAFE.getLong(inputBase + cursor); - word2 = UNSAFE.getLong(inputBase + cursor + Long.BYTES); + long nameLen; + long nameStartAddress = inputBase + cursor; + if (cursor + DANGER_ZONE_LENGTH <= inputSize) { + withinSafeZone = true; + word1 = UNSAFE.getLong(nameStartAddress); + word2 = UNSAFE.getLong(nameStartAddress + Long.BYTES); + nameLen = nameLen(word1, word2, withinSafeZone); + word1 = maskWord(word1, nameLen); + word2 = maskWord(word2, nameLen - Long.BYTES); } else { + withinSafeZone = false; UNSAFE.putLong(nameBufBase, 0); UNSAFE.putLong(nameBufBase + Long.BYTES, 0); - UNSAFE.copyMemory(inputBase + cursor, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor)); + UNSAFE.copyMemory(nameStartAddress, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor)); word1 = UNSAFE.getLong(nameBufBase); word2 = UNSAFE.getLong(nameBufBase + Long.BYTES); + nameLen = nameLen(word1, word2, withinSafeZone); } - long posOfSemicolon = posOfSemicolon(word1, word2); - word1 = maskWord(word1, posOfSemicolon - cursor); - word2 = maskWord(word2, posOfSemicolon - cursor - Long.BYTES); long hash = hash(word1); - long namePos = cursor; - long nameLen = posOfSemicolon - cursor; - assert nameLen <= 100 : "nameLen > 100"; - int temperature = parseTemperatureAndAdvanceCursor(posOfSemicolon); - updateStats(hash, namePos, nameLen, word1, word2, temperature); + assert nameLen > 0 && nameLen <= 100 : nameLen; + long tempStartAddress = nameStartAddress + nameLen + 1; + int temperature = withinSafeZone + ? parseTemperatureSwarAndAdvanceCursor(tempStartAddress) + : parseTemperatureSimpleAndAdvanceCursor(tempStartAddress); + updateStats(hash, nameStartAddress, nameLen, word1, word2, temperature, withinSafeZone); } } - private void updateStats(long hash, long namePos, long nameLen, long nameWord1, long nameWord2, int temperature) { + private void updateStats( + long hash, long nameStartAddress, long nameLen, long nameWord1, long nameWord2, + int temperature, boolean withinSafeZone) { int tableIndex = (int) (hash & TABLE_INDEX_MASK); while (true) { stats.gotoIndex(tableIndex); - if (stats.hash() == hash && stats.nameLen() == nameLen - && nameEquals(stats.nameAddress(), inputBase + namePos, nameLen, nameWord1, nameWord2)) { + if (stats.hash() == hash && stats.nameLen() == nameLen && nameEquals( + stats.nameAddress(), nameStartAddress, nameLen, nameWord1, nameWord2, withinSafeZone)) { stats.setSum(stats.sum() + temperature); stats.setCount(stats.count() + 1); stats.setMin((short) Integer.min(stats.min(), temperature)); @@ -204,72 +217,58 @@ public class CalculateAverage_mtopolnik { stats.setCount(1); stats.setMin((short) temperature); stats.setMax((short) temperature); - UNSAFE.copyMemory(inputBase + namePos, stats.nameAddress(), nameLen); + UNSAFE.copyMemory(nameStartAddress, stats.nameAddress(), nameLen); return; } } - private int parseTemperatureAndAdvanceCursor(long semicolonPos) { - long startOffset = semicolonPos + 1; - if (startOffset <= inputSize - Long.BYTES) { - return parseTemperatureSwarAndAdvanceCursor(startOffset); - } - return parseTemperatureSimpleAndAdvanceCursor(startOffset); - } - // Credit: merykitty - private int parseTemperatureSwarAndAdvanceCursor(long startOffset) { - long word = UNSAFE.getLong(inputBase + startOffset); + private int parseTemperatureSwarAndAdvanceCursor(long tempStartAddress) { + long word = UNSAFE.getLong(tempStartAddress); final long negated = ~word; final int dotPos = Long.numberOfTrailingZeros(negated & 0x10101000); + cursor = (tempStartAddress + (dotPos / 8) + 3) - inputBase; final long signed = (negated << 59) >> 63; final long removeSignMask = ~(signed & 0xFF); final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L; final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; - final int temperature = (int) ((absValue ^ signed) - signed); - cursor = startOffset + (dotPos / 8) + 3; - return temperature; + return (int) ((absValue ^ signed) - signed); } - private int parseTemperatureSimpleAndAdvanceCursor(long startOffset) { + private int parseTemperatureSimpleAndAdvanceCursor(long tempStartAddress) { final byte minus = (byte) '-'; final byte zero = (byte) '0'; final byte dot = (byte) '.'; - // Temperature plus the following newline is at least 4 chars, so this is always safe: - int fourCh = UNSAFE.getInt(inputBase + startOffset); - final int mask = 0xFF; - byte ch = (byte) (fourCh & mask); - int shift = 0; + byte ch = UNSAFE.getByte(tempStartAddress); + long address = tempStartAddress; int temperature; int sign; if (ch == minus) { sign = -1; - shift += 8; - ch = (byte) ((fourCh & (mask << shift)) >>> shift); + address++; + ch = UNSAFE.getByte(address); } else { sign = 1; } temperature = ch - zero; - shift += 8; - ch = (byte) ((fourCh & (mask << shift)) >>> shift); + address++; + ch = UNSAFE.getByte(address); if (ch == dot) { - shift += 8; - ch = (byte) ((fourCh & (mask << shift)) >>> shift); + address++; + ch = UNSAFE.getByte(address); } else { temperature = 10 * temperature + (ch - zero); - shift += 16; - // The last character may be past the four loaded bytes, load it from memory. - // Checking that with another `if` is self-defeating for performance. - ch = UNSAFE.getByte(inputBase + startOffset + (shift / 8)); + address += 2; + ch = UNSAFE.getByte(address); } temperature = 10 * temperature + (ch - zero); - // `shift` holds the number of bits in the temperature field. + // address - inputBase is the length of the temperature field. // A newline character follows the temperature, and so we advance // the cursor past the newline to the start of the next line. - cursor = startOffset + (shift / 8) + 2; + cursor = (address + 2) - inputBase; return sign * temperature; } @@ -286,15 +285,27 @@ public class CalculateAverage_mtopolnik { return hash; } - private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2) { + private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2, + boolean withinSafeZone) { boolean mismatch1 = maskWord(inputWord1, len) != UNSAFE.getLong(statsAddr); boolean mismatch2 = maskWord(inputWord2, len - Long.BYTES) != UNSAFE.getLong(statsAddr + Long.BYTES); - if (mismatch1 | mismatch2) { - return false; + if (len <= 2 * Long.BYTES) { + return !(mismatch1 | mismatch2); } - for (int i = 2 * Long.BYTES; i < len; i++) { - if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) { - return false; + if (withinSafeZone) { + int i = 2 * Long.BYTES; + for (; i <= len - Long.BYTES; i += Long.BYTES) { + if (UNSAFE.getLong(inputAddr + i) != UNSAFE.getLong(statsAddr + i)) { + return false; + } + } + return maskWord(UNSAFE.getLong(inputAddr + i), len - i) == UNSAFE.getLong(statsAddr + i); + } + else { + for (int i = 2 * Long.BYTES; i < len; i++) { + if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) { + return false; + } } } return true; @@ -311,44 +322,62 @@ public class CalculateAverage_mtopolnik { // Adapted from https://jameshfisher.com/2017/01/24/bitwise-check-for-zero-byte/ // and https://github.com/ashvardanian/StringZilla/blob/14e7a78edcc16b031c06b375aac1f66d8f19d45a/stringzilla/stringzilla.h#L139-L169 - long posOfSemicolon(long word1, long word2) { - long diff = word1 ^ BROADCAST_SEMICOLON; - long matchBits1 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; - diff = word2 ^ BROADCAST_SEMICOLON; - long matchBits2 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; - if ((matchBits1 | matchBits2) != 0) { - int trailing1 = Long.numberOfTrailingZeros(matchBits1); - int match1IsNonZero = trailing1 & 63; - match1IsNonZero |= match1IsNonZero >>> 3; - match1IsNonZero |= match1IsNonZero >>> 1; - match1IsNonZero |= match1IsNonZero >>> 1; - // Now match1IsNonZero is 1 if it's non-zero, else 0. Use it to - // raise the lowest bit in traling2 if trailing1 is nonzero. This forces - // trailing2 to be zero if trailing1 is non-zero. - int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63; - return cursor + ((trailing1 | trailing2) >> 3); - } - long offset = cursor + 2 * Long.BYTES; - for (; offset <= inputSize - Long.BYTES; offset += Long.BYTES) { - var block = UNSAFE.getLong(inputBase + offset); - diff = block ^ BROADCAST_SEMICOLON; - long matchBits = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; - if (matchBits != 0) { - return offset + Long.numberOfTrailingZeros(matchBits) / 8; + long nameLen(long word1, long word2, boolean withinSafeZone) { + { + long matchBits1 = matchBits(word1); + long matchBits2 = matchBits(word2); + if ((matchBits1 | matchBits2) != 0) { + int trailing1 = Long.numberOfTrailingZeros(matchBits1); + int match1IsNonZero = trailing1 & 63; + match1IsNonZero |= match1IsNonZero >>> 3; + match1IsNonZero |= match1IsNonZero >>> 1; + match1IsNonZero |= match1IsNonZero >>> 1; + // Now match1IsNonZero is 1 if it's non-zero, else 0. Use it to + // raise the lowest bit in trailing2 if trailing1 is nonzero. This forces + // trailing2 to be zero if trailing1 is non-zero. + int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63; + // trailing1 | trailing2 works like trailing1 + trailing2 because if trailing2 is non-zero, + // then trailing1 is 64, and since trailing2 is < 64, there's no bit overlap. + return (trailing1 | trailing2) >> 3; } } - return posOfSemicolonSimple(offset); + long nameStartAddress = inputBase + cursor; + long address = nameStartAddress + 2 * Long.BYTES; + long limit = inputBase + inputSize; + if (withinSafeZone) { + for (; address < limit; address += Long.BYTES) { + var block = maskWord(UNSAFE.getLong(address), limit - address); + long matchBits = matchBits(block); + if (matchBits != 0) { + return address + (Long.numberOfTrailingZeros(matchBits) >> 3) - nameStartAddress; + } + } + throw new RuntimeException("Semicolon not found"); + } + return addrOfSemicolonSafe(address, limit) - nameStartAddress; } - private long posOfSemicolonSimple(long offset) { - for (; offset < inputSize; offset++) { - if (UNSAFE.getByte(inputBase + offset) == SEMICOLON) { - return offset; + private static long addrOfSemicolonSafe(long address, long limit) { + for (; address < limit - Long.BYTES + 1; address += Long.BYTES) { + var block = UNSAFE.getLong(address); + long matchBits = matchBits(block); + if (matchBits != 0) { + return address + (Long.numberOfTrailingZeros(matchBits) >> 3); + } + } + for (; address < limit; address++) { + if (UNSAFE.getByte(address) == SEMICOLON) { + return address; } } throw new RuntimeException("Semicolon not found"); } + private static long matchBits(long word) { + long diff = word ^ BROADCAST_SEMICOLON; + return (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80; + } + // Copies the results from native memory to Java heap and puts them into the results array. private void exportResults() { var exportedStats = new ArrayList(10_000);