From 2c432abb964db2a2556a06ff00e12221e4f58995 Mon Sep 17 00:00:00 2001 From: Jaromir Hamala Date: Tue, 23 Jan 2024 18:29:22 +0100 Subject: [PATCH] jerrinot's improvement - fast-path for short keys (#545) * fast-path for keys<16 bytes * fix off-by-one error: the mask is wrong for the 2nd word when len == 16 * fewer chunks per thread: seems like compact code wins. on my test box anyway. --- .../onebrc/CalculateAverage_jerrinot.java | 441 +++++++++++------- 1 file changed, 277 insertions(+), 164 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java index 13e48ae..2492c0f 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java @@ -38,7 +38,7 @@ public class CalculateAverage_jerrinot { // todo: with hyper-threading enable we would be better of with availableProcessors / 2; // todo: validate the testing env. params. 
private static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors(); - // private static final int THREAD_COUNT = 1; + // private static final int THREAD_COUNT = 4; private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL; @@ -61,7 +61,7 @@ public class CalculateAverage_jerrinot { final File file = new File(MEASUREMENTS_TXT); final long length = file.length(); // final int chunkCount = Runtime.getRuntime().availableProcessors(); - int chunkPerThread = 4; + int chunkPerThread = 3; final int chunkCount = THREAD_COUNT * chunkPerThread; final var chunkStartOffsets = new long[chunkCount + 1]; try (var raf = new RandomAccessFile(file, "r")) { @@ -88,10 +88,8 @@ public class CalculateAverage_jerrinot { long endB = chunkStartOffsets[i * chunkPerThread + 2]; long startC = chunkStartOffsets[i * chunkPerThread + 2]; long endC = chunkStartOffsets[i * chunkPerThread + 3]; - long startD = chunkStartOffsets[i * chunkPerThread + 3]; - long endD = chunkStartOffsets[i * chunkPerThread + 4]; - Processor processor = new Processor(startA, endA, startB, endB, startC, endC, startD, endD); + Processor processor = new Processor(startA, endA, startB, endB, startC, endC); processors[i] = processor; Thread thread = new Thread(processor); threads[i] = thread; @@ -105,9 +103,7 @@ public class CalculateAverage_jerrinot { long endB = chunkStartOffsets[ownIndex * chunkPerThread + 2]; long startC = chunkStartOffsets[ownIndex * chunkPerThread + 2]; long endC = chunkStartOffsets[ownIndex * chunkPerThread + 3]; - long startD = chunkStartOffsets[ownIndex * chunkPerThread + 3]; - long endD = chunkStartOffsets[ownIndex * chunkPerThread + 4]; - Processor processor = new Processor(startA, endA, startB, endB, startC, endC, startD, endD); + Processor processor = new Processor(startA, endA, startB, endB, startC, endC); processor.run(); var accumulator = new TreeMap(); @@ -119,29 +115,33 @@ public class CalculateAverage_jerrinot { processors[i].accumulateStatus(accumulator); } - var sb 
= new StringBuilder(); - boolean first = true; - for (Map.Entry statsEntry : accumulator.entrySet()) { - if (first) { - sb.append("{"); - first = false; - } - else { - sb.append(", "); - } - var value = statsEntry.getValue(); - var name = statsEntry.getKey(); - int min = value.min; - int max = value.max; - int count = value.count; - long sum2 = value.sum; - sb.append(String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum2 / count) / 10.0, max / 10.0)); - } - System.out.print(sb); - System.out.println('}'); + printResults(accumulator); } } + private static void printResults(TreeMap accumulator) { + var sb = new StringBuilder(10000); + boolean first = true; + for (Map.Entry statsEntry : accumulator.entrySet()) { + if (first) { + sb.append("{"); + first = false; + } + else { + sb.append(", "); + } + var value = statsEntry.getValue(); + var name = statsEntry.getKey(); + int min = value.min; + int max = value.max; + int count = value.count; + long sum2 = value.sum; + sb.append(String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum2 / count) / 10.0, max / 10.0)); + } + sb.append('}'); + System.out.println(sb); + } + public static int ceilPow2(int i) { i--; i |= i >> 1; @@ -154,51 +154,65 @@ public class CalculateAverage_jerrinot { private static class Processor implements Runnable { private static final int MAX_UNIQUE_KEYS = 10000; - private static final int MAP_SLOT_COUNT = ceilPow2(MAX_UNIQUE_KEYS); + private static final int MAPS_SLOT_COUNT = ceilPow2(MAX_UNIQUE_KEYS); private static final int STATION_MAX_NAME_BYTES = 104; - private static final long COUNT_OFFSET = 0; - private static final long MIN_OFFSET = 4; - private static final long MAX_OFFSET = 8; - private static final long SUM_OFFSET = 12; - private static final long LEN_OFFSET = 20; - private static final long NAME_OFFSET = 24; + private static final long MAP_COUNT_OFFSET = 0; + private static final long MAP_MIN_OFFSET = 4; + private static final long 
MAP_MAX_OFFSET = 8; + private static final long MAP_SUM_OFFSET = 12; + private static final long MAP_LEN_OFFSET = 20; + private static final long SLOW_MAP_NAME_OFFSET = 24; - private static final int MAP_ENTRY_SIZE_BYTES = Integer.BYTES // count // 0 + // private int longestChain = 0; + + private static final int SLOW_MAP_ENTRY_SIZE_BYTES = Integer.BYTES // count // 0 + Integer.BYTES // min // +4 + Integer.BYTES // max // +8 + Long.BYTES // sum // +12 + Integer.BYTES // station name len // +20 + Long.BYTES; // station name ptr // 24 - private static final int MAP_SIZE_BYTES = MAP_SLOT_COUNT * MAP_ENTRY_SIZE_BYTES; - private static final int MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES; - private static final long MAP_MASK = MAP_SLOT_COUNT - 1; + private static final long FAST_MAP_NAME_PART1 = 24; + private static final long FAST_MAP_NAME_PART2 = 32; - private final long map; - private long currentNamesPtr; - private final long namesHi; + private static final int FAST_MAP_ENTRY_SIZE_BYTES = Integer.BYTES // count // 0 + + Integer.BYTES // min // +4 + + Integer.BYTES // max // +8 + + Long.BYTES // sum // +12 + + Integer.BYTES // station name len // +20 + + Long.BYTES // station name part 1 // 24 + + Long.BYTES; // station name part 2 // 32 + + private static final int SLOW_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * SLOW_MAP_ENTRY_SIZE_BYTES; + private static final int FAST_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * FAST_MAP_ENTRY_SIZE_BYTES; + private static final int SLOW_MAP_MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES; + private static final long MAP_MASK = MAPS_SLOT_COUNT - 1; + + private long slowMap; + private long slowMapNamesPtr; + private long slowMapNamesLo; + private long fastMap; private long cursorA; private long endA; private long cursorB; private long endB; private long cursorC; private long endC; - private long cursorD; - private long endD; + private HashMap stats = new HashMap<>(1000); // private long maxClusterLen; // credit: merykitty 
private long parseAndStoreTemperature(long startCursor, long baseEntryPtr, long word) { // long word = UNSAFE.getLong(startCursor); - long countPtr = baseEntryPtr + COUNT_OFFSET; + long countPtr = baseEntryPtr + MAP_COUNT_OFFSET; int cnt = UNSAFE.getInt(countPtr); UNSAFE.putInt(countPtr, cnt + 1); - long minPtr = baseEntryPtr + MIN_OFFSET; - long maxPtr = baseEntryPtr + MAX_OFFSET; - long sumPtr = baseEntryPtr + SUM_OFFSET; + long minPtr = baseEntryPtr + MAP_MIN_OFFSET; + long maxPtr = baseEntryPtr + MAP_MAX_OFFSET; + long sumPtr = baseEntryPtr + MAP_SUM_OFFSET; int min = UNSAFE.getInt(minPtr); int max = UNSAFE.getInt(maxPtr); @@ -232,77 +246,103 @@ public class CalculateAverage_jerrinot { // todo: immutability cost us in allocations, but that's probably peanuts in the grand scheme of things. still worth checking // maybe JVM trusting Final in Records offsets it ..a test is needed record StationStats(int min, int max, int count, long sum) { + StationStats mergeWith(StationStats other) { + return new StationStats(Math.min(min, other.min), Math.max(max, other.max), count + other.count, sum + other.sum); + } } void accumulateStatus(TreeMap accumulator) { - for (long baseAddress = map; baseAddress < map + MAP_SIZE_BYTES; baseAddress += MAP_ENTRY_SIZE_BYTES) { - long len = UNSAFE.getInt(baseAddress + LEN_OFFSET); + for (Map.Entry entry : stats.entrySet()) { + String name = entry.getKey(); + StationStats localStats = entry.getValue(); + + StationStats globalStats = accumulator.get(name); + if (globalStats == null) { + accumulator.put(name, localStats); + } + else { + accumulator.put(name, globalStats.mergeWith(localStats)); + } + } + } + + Processor(long startA, long endA, long startB, long endB, long startC, long endC) { + this.cursorA = startA; + this.cursorB = startB; + this.cursorC = startC; + this.endA = endA; + this.endB = endB; + this.endC = endC; + } + + private void doTail() { + doOne(cursorA, endA); + doOne(cursorB, endB); + doOne(cursorC, endC); + + 
transferToHeap(); + UNSAFE.freeMemory(fastMap); + UNSAFE.freeMemory(slowMap); + UNSAFE.freeMemory(slowMapNamesLo); + } + + private void transferToHeap() { + for (long baseAddress = slowMap; baseAddress < slowMap + SLOW_MAP_SIZE_BYTES; baseAddress += SLOW_MAP_ENTRY_SIZE_BYTES) { + long len = UNSAFE.getInt(baseAddress + MAP_LEN_OFFSET); if (len == 0) { continue; } byte[] nameArr = new byte[(int) len]; - long baseNameAddr = UNSAFE.getLong(baseAddress + NAME_OFFSET); + long baseNameAddr = UNSAFE.getLong(baseAddress + SLOW_MAP_NAME_OFFSET); for (int i = 0; i < len; i++) { nameArr[i] = UNSAFE.getByte(baseNameAddr + i); } String name = new String(nameArr); - int min = UNSAFE.getInt(baseAddress + MIN_OFFSET); - int max = UNSAFE.getInt(baseAddress + MAX_OFFSET); - int count = UNSAFE.getInt(baseAddress + COUNT_OFFSET); - long sum = UNSAFE.getLong(baseAddress + SUM_OFFSET); + int min = UNSAFE.getInt(baseAddress + MAP_MIN_OFFSET); + int max = UNSAFE.getInt(baseAddress + MAP_MAX_OFFSET); + int count = UNSAFE.getInt(baseAddress + MAP_COUNT_OFFSET); + long sum = UNSAFE.getLong(baseAddress + MAP_SUM_OFFSET); - var v = accumulator.get(name); + stats.put(name, new StationStats(min, max, count, sum)); + } + + for (long baseAddress = fastMap; baseAddress < fastMap + FAST_MAP_SIZE_BYTES; baseAddress += FAST_MAP_ENTRY_SIZE_BYTES) { + long len = UNSAFE.getInt(baseAddress + MAP_LEN_OFFSET); + if (len == 0) { + continue; + } + byte[] nameArr = new byte[(int) len]; + long baseNameAddr = baseAddress + FAST_MAP_NAME_PART1; + for (int i = 0; i < len; i++) { + nameArr[i] = UNSAFE.getByte(baseNameAddr + i); + } + String name = new String(nameArr); + int min = UNSAFE.getInt(baseAddress + MAP_MIN_OFFSET); + int max = UNSAFE.getInt(baseAddress + MAP_MAX_OFFSET); + int count = UNSAFE.getInt(baseAddress + MAP_COUNT_OFFSET); + long sum = UNSAFE.getLong(baseAddress + MAP_SUM_OFFSET); + + var v = stats.get(name); if (v == null) { - accumulator.put(name, new StationStats(min, max, count, sum)); + 
stats.put(name, new StationStats(min, max, count, sum)); } else { - accumulator.put(name, new StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum)); + stats.put(name, new StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum)); } } } - Processor(long startA, long endA, long startB, long endB, long startC, long endC, long startD, long endD) { - this.cursorA = startA; - this.cursorB = startB; - this.cursorC = startC; - this.cursorD = startD; - this.endA = endA; - this.endB = endB; - this.endC = endC; - this.endD = endD; - this.map = UNSAFE.allocateMemory(MAP_SIZE_BYTES); - this.currentNamesPtr = UNSAFE.allocateMemory(MAP_NAMES_BYTES); - this.namesHi = currentNamesPtr + MAP_NAMES_BYTES; - - int i; - for (i = 0; i < MAP_SIZE_BYTES; i += 8) { - UNSAFE.putLong(map + i, 0); - } - for (i = i - 8; i < MAP_SIZE_BYTES; i++) { - UNSAFE.putByte(map + i, (byte) 0); - } - UNSAFE.setMemory(currentNamesPtr, MAP_NAMES_BYTES, (byte) 0); - } - - private void doTail() { - // todo: we would be probably better of without all that code dup. 
("compilers hates him!") - // System.out.println("done ILP"); - doOne(cursorA, endA); - // System.out.println("done A"); - doOne(cursorB, endB); - // System.out.println("done B"); - doOne(cursorC, endC); - // System.out.println("done C"); - doOne(cursorD, endD); - // System.out.println("done D"); - } - private void doOne(long cursor, long endA) { while (cursor < endA) { long start = cursor; long currentWord = UNSAFE.getLong(cursor); long mask = getDelimiterMask(currentWord); - long maskedFirstWord = currentWord & ((mask - 1) ^ mask) >>> 8; + long firstWordMask = ((mask - 1) ^ mask) >>> 8; + final long isMaskZeroA = ((mask | -mask) >>> 63) ^ 1; + long ext = -isMaskZeroA & 0xFF00_0000_0000_0000L; + firstWordMask |= ext; + + long maskedFirstWord = currentWord & firstWordMask; long hash = hash(maskedFirstWord); while (mask == 0) { cursor += 8; @@ -312,7 +352,9 @@ public class CalculateAverage_jerrinot { final int delimiterByte = Long.numberOfTrailingZeros(mask); final long semicolon = cursor + (delimiterByte >> 3); final long maskedWord = currentWord & ((mask - 1) ^ mask) >>> 8; - long baseEntryPtr = getOrCreateEntryBaseOffset(semicolon, start, (int) hash, maskedWord); + + long len = semicolon - start; + long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, (int) hash, maskedWord); long temperatureWord = UNSAFE.getLong(semicolon + 1); cursor = parseAndStoreTemperature(semicolon + 1, baseEntryPtr, temperatureWord); } @@ -331,133 +373,204 @@ public class CalculateAverage_jerrinot { @Override public void run() { - while (cursorA < endA && cursorB < endB && cursorC < endC && cursorD < endD) { + this.slowMap = UNSAFE.allocateMemory(SLOW_MAP_SIZE_BYTES); + this.slowMapNamesPtr = UNSAFE.allocateMemory(SLOW_MAP_MAP_NAMES_BYTES); + this.slowMapNamesLo = slowMapNamesPtr; + this.fastMap = UNSAFE.allocateMemory(FAST_MAP_SIZE_BYTES); + UNSAFE.setMemory(slowMap, SLOW_MAP_SIZE_BYTES, (byte) 0); + UNSAFE.setMemory(fastMap, FAST_MAP_SIZE_BYTES, (byte) 0); + 
UNSAFE.setMemory(slowMapNamesPtr, SLOW_MAP_MAP_NAMES_BYTES, (byte) 0); + + while (cursorA < endA && cursorB < endB && cursorC < endC) { long startA = cursorA; long startB = cursorB; long startC = cursorC; - long startD = cursorD; long currentWordA = UNSAFE.getLong(startA); long currentWordB = UNSAFE.getLong(startB); long currentWordC = UNSAFE.getLong(startC); - long currentWordD = UNSAFE.getLong(startD); - // credits for the hashing idea: mtopolnik long maskA = getDelimiterMask(currentWordA); - long maskedFirstWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8; + long maskB = getDelimiterMask(currentWordB); + long maskC = getDelimiterMask(currentWordC); + + long firstWordMaskA = (maskA ^ (maskA - 1)) >>> 8; + long firstWordMaskB = (maskB ^ (maskB - 1)) >>> 8; + long firstWordMaskC = (maskC ^ (maskC - 1)) >>> 8; + + final long isMaskZeroA = ((maskA | -maskA) >>> 63) ^ 1; + final long isMaskZeroB = ((maskB | -maskB) >>> 63) ^ 1; + final long isMaskZeroC = ((maskC | -maskC) >>> 63) ^ 1; + + long extA = -isMaskZeroA & 0xFF00_0000_0000_0000L; + long extB = -isMaskZeroB & 0xFF00_0000_0000_0000L; + long extC = -isMaskZeroC & 0xFF00_0000_0000_0000L; + + firstWordMaskA |= extA; + firstWordMaskB |= extB; + firstWordMaskC |= extC; + + long maskedFirstWordA = currentWordA & firstWordMaskA; + long maskedFirstWordB = currentWordB & firstWordMaskB; + long maskedFirstWordC = currentWordC & firstWordMaskC; + + // assertMasks(isMaskZeroA, maskA); + long hashA = hash(maskedFirstWordA); + long hashB = hash(maskedFirstWordB); + long hashC = hash(maskedFirstWordC); + + cursorA += isMaskZeroA * 8; + cursorB += isMaskZeroB * 8; + cursorC += isMaskZeroC * 8; + + currentWordA = UNSAFE.getLong(cursorA); + currentWordB = UNSAFE.getLong(cursorB); + currentWordC = UNSAFE.getLong(cursorC); + + maskA = getDelimiterMask(currentWordA); while (maskA == 0) { cursorA += 8; currentWordA = UNSAFE.getLong(cursorA); maskA = getDelimiterMask(currentWordA); } - final int delimiterByteA = 
Long.numberOfTrailingZeros(maskA); - final long semicolonA = cursorA + (delimiterByteA >> 3); - long temperatureWordA = UNSAFE.getLong(semicolonA + 1); - final long maskedWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8; - - long maskB = getDelimiterMask(currentWordB); - long maskedFirstWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8; - long hashB = hash(maskedFirstWordB); + maskB = getDelimiterMask(currentWordB); while (maskB == 0) { cursorB += 8; currentWordB = UNSAFE.getLong(cursorB); maskB = getDelimiterMask(currentWordB); } - final int delimiterByteB = Long.numberOfTrailingZeros(maskB); - final long semicolonB = cursorB + (delimiterByteB >> 3); - long temperatureWordB = UNSAFE.getLong(semicolonB + 1); - final long maskedWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8; - - long maskC = getDelimiterMask(currentWordC); - long maskedFirstWordC = currentWordC & ((maskC - 1) ^ maskC) >>> 8; - long hashC = hash(maskedFirstWordC); + maskC = getDelimiterMask(currentWordC); while (maskC == 0) { cursorC += 8; currentWordC = UNSAFE.getLong(cursorC); maskC = getDelimiterMask(currentWordC); } + + final int delimiterByteA = Long.numberOfTrailingZeros(maskA); + final int delimiterByteB = Long.numberOfTrailingZeros(maskB); final int delimiterByteC = Long.numberOfTrailingZeros(maskC); + + final long semicolonA = cursorA + (delimiterByteA >> 3); + final long semicolonB = cursorB + (delimiterByteB >> 3); final long semicolonC = cursorC + (delimiterByteC >> 3); - long temperatureWordC = UNSAFE.getLong(semicolonC + 1); + + long digitStartA = semicolonA + 1; + long digitStartB = semicolonB + 1; + long digitStartC = semicolonC + 1; + long temperatureWordA = UNSAFE.getLong(digitStartA); + long temperatureWordB = UNSAFE.getLong(digitStartB); + long temperatureWordC = UNSAFE.getLong(digitStartC); + + final long maskedWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8; + final long maskedWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8; final long maskedWordC = currentWordC & 
((maskC - 1) ^ maskC) >>> 8; - long maskD = getDelimiterMask(currentWordD); - long maskedFirstWordD = currentWordD & ((maskD - 1) ^ maskD) >>> 8; - long hashD = hash(maskedFirstWordD); - while (maskD == 0) { - cursorD += 8; - currentWordD = UNSAFE.getLong(cursorD); - maskD = getDelimiterMask(currentWordD); + long lenA = semicolonA - startA; + long lenB = semicolonB - startB; + long lenC = semicolonC - startC; + + long baseEntryPtrA; + if (lenA > 15) { + baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, (int) hashA, maskedWordA); + } + else { + baseEntryPtrA = getOrCreateEntryBaseOffsetFast(lenA, (int) hashA, maskedWordA, maskedFirstWordA); } - final int delimiterByteD = Long.numberOfTrailingZeros(maskD); - final long semicolonD = cursorD + (delimiterByteD >> 3); - long temperatureWordD = UNSAFE.getLong(semicolonD + 1); - final long maskedWordD = currentWordD & ((maskD - 1) ^ maskD) >>> 8; - long baseEntryPtrA = getOrCreateEntryBaseOffset(semicolonA, startA, (int) hashA, maskedWordA); - long baseEntryPtrB = getOrCreateEntryBaseOffset(semicolonB, startB, (int) hashB, maskedWordB); - long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, (int) hashC, maskedWordC); - long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, (int) hashD, maskedWordD); + long baseEntryPtrB; + if (lenB > 15) { + baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, (int) hashB, maskedWordB); + } + else { + baseEntryPtrB = getOrCreateEntryBaseOffsetFast(lenB, (int) hashB, maskedWordB, maskedFirstWordB); + } - cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA, temperatureWordA); - cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB, temperatureWordB); - cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC, temperatureWordC); - cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD, temperatureWordD); + long baseEntryPtrC; + if (lenC > 15) { + baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, 
startC, (int) hashC, maskedWordC); + } + else { + baseEntryPtrC = getOrCreateEntryBaseOffsetFast(lenC, (int) hashC, maskedWordC, maskedFirstWordC); + } + + cursorA = parseAndStoreTemperature(digitStartA, baseEntryPtrA, temperatureWordA); + cursorB = parseAndStoreTemperature(digitStartB, baseEntryPtrB, temperatureWordB); + cursorC = parseAndStoreTemperature(digitStartC, baseEntryPtrC, temperatureWordC); } doTail(); + // System.out.println("Longest chain: " + longestChain); } - private long getOrCreateEntryBaseOffset(long semicolonPtr, long startPtr, int hash, long maskedWord) { - long lenLong = semicolonPtr - startPtr; + private long getOrCreateEntryBaseOffsetFast(long lenLong, int hash, long maskedLastWord, long maskedFirstWord) { int lenA = (int) lenLong; - long mapIndexA = hash & MAP_MASK; for (;;) { - long basePtr = mapIndexA * MAP_ENTRY_SIZE_BYTES + map; - long lenPtr = basePtr + LEN_OFFSET; - long namePtr = basePtr + NAME_OFFSET; + long basePtr = mapIndexA * FAST_MAP_ENTRY_SIZE_BYTES + fastMap; + long lenPtr = basePtr + MAP_LEN_OFFSET; int len = UNSAFE.getInt(lenPtr); if (len == lenA) { - namePtr = UNSAFE.getLong(basePtr + NAME_OFFSET); - if (nameMatch(startPtr, maskedWord, namePtr, lenLong)) { + long namePart1 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART1); + long namePart2 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART2); + if (namePart1 == maskedFirstWord && namePart2 == maskedLastWord) { return basePtr; } } else if (len == 0) { - // todo: uncommon branch maybe? 
- // empty slot - UNSAFE.putLong(namePtr, currentNamesPtr); UNSAFE.putInt(lenPtr, lenA); // todo: this could be a single putLong() - UNSAFE.putInt(basePtr + MAX_OFFSET, Integer.MIN_VALUE); - UNSAFE.putInt(basePtr + MIN_OFFSET, Integer.MAX_VALUE); - UNSAFE.copyMemory(startPtr, currentNamesPtr, lenA); - long consume = (lenLong & ~7L) + 8; - currentNamesPtr += consume; - assert currentNamesPtr <= namesHi; + UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE); + UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE); + UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART1, maskedFirstWord); + UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART2, maskedLastWord); return basePtr; } mapIndexA = ++mapIndexA & MAP_MASK; } } - private static boolean nameMatch(long startA, long maskedWordA, long namePtr, long len) { - // long namePtr = basePtr + NAME_OFFSET; - long fullLen = len & ~7L; - long offset; + private long getOrCreateEntryBaseOffsetSlow(long lenLong, long startPtr, int hash, long maskedLastWord) { + long fullLen = lenLong & ~7L; + int lenA = (int) lenLong; + long mapIndexA = hash & MAP_MASK; + for (;;) { + long basePtr = mapIndexA * SLOW_MAP_ENTRY_SIZE_BYTES + slowMap; + long lenPtr = basePtr + MAP_LEN_OFFSET; + long namePtr = basePtr + SLOW_MAP_NAME_OFFSET; + int len = UNSAFE.getInt(lenPtr); + if (len == lenA) { + namePtr = UNSAFE.getLong(basePtr + SLOW_MAP_NAME_OFFSET); + if (nameMatch(startPtr, maskedLastWord, namePtr, fullLen)) { + return basePtr; + } + } + else if (len == 0) { + UNSAFE.putLong(namePtr, slowMapNamesPtr); + UNSAFE.putInt(lenPtr, lenA); + UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE); + UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE); + UNSAFE.copyMemory(startPtr, slowMapNamesPtr, lenA); + long alignedLen = (lenLong & ~7L) + 8; + slowMapNamesPtr += alignedLen; + return basePtr; + } + mapIndexA = ++mapIndexA & MAP_MASK; + } + } - // todo: this is worth exploring further. 
- // @mtopolnik has an interesting algo with 2 unconditioned long loads: this is sufficient - // for majority of names. so we would be left with just a single branch which is almost never taken? + private static boolean nameMatch(long start, long maskedLastWord, long namePtr, long fullLen) { + return nameMatchSlow(start, namePtr, fullLen, maskedLastWord); + } + + private static boolean nameMatchSlow(long start, long namePtr, long fullLen, long maskedLastWord) { + long offset; for (offset = 0; offset < fullLen; offset += 8) { - if (UNSAFE.getLong(startA + offset) != UNSAFE.getLong(namePtr + offset)) { + if (UNSAFE.getLong(start + offset) != UNSAFE.getLong(namePtr + offset)) { return false; } } - long maskedWordInMap = UNSAFE.getLong(namePtr + fullLen); - return (maskedWordInMap == maskedWordA); + return (maskedWordInMap == maskedLastWord); } }