From 114ba76d20f946ac6421aff73cd69387b0cb15b7 Mon Sep 17 00:00:00 2001 From: Jaromir Hamala Date: Sat, 20 Jan 2024 20:06:31 +0100 Subject: [PATCH] jerrinot's improvement (#514) * refactoring * segregated heap for names also a different hashing function. turns out hashing just first word is good enough --- .../onebrc/CalculateAverage_jerrinot.java | 151 ++++++++---------- 1 file changed, 67 insertions(+), 84 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java index 5373cb0..13e48ae 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java @@ -38,7 +38,7 @@ public class CalculateAverage_jerrinot { // todo: with hyper-threading enable we would be better of with availableProcessors / 2; // todo: validate the testing env. params. private static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors(); - // private static final int THREAD_COUNT = 4; + // private static final int THREAD_COUNT = 1; private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL; @@ -153,8 +153,9 @@ public class CalculateAverage_jerrinot { } private static class Processor implements Runnable { - private static final int MAP_SLOT_COUNT = ceilPow2(10000); - private static final int STATION_MAX_NAME_BYTES = 120; + private static final int MAX_UNIQUE_KEYS = 10000; + private static final int MAP_SLOT_COUNT = ceilPow2(MAX_UNIQUE_KEYS); + private static final int STATION_MAX_NAME_BYTES = 104; private static final long COUNT_OFFSET = 0; private static final long MIN_OFFSET = 4; @@ -163,20 +164,20 @@ public class CalculateAverage_jerrinot { private static final long LEN_OFFSET = 20; private static final long NAME_OFFSET = 24; - private static final int MAP_ENTRY_SIZE_BYTES = +Integer.BYTES // count // 0 + private static final int MAP_ENTRY_SIZE_BYTES = Integer.BYTES // count // 0 + Integer.BYTES // min // +4 + Integer.BYTES // max // +8 + Long.BYTES // sum // +12 + Integer.BYTES // station name len // +20 - + STATION_MAX_NAME_BYTES; // +24 + + Long.BYTES; // station name ptr // 24 private static final int MAP_SIZE_BYTES = MAP_SLOT_COUNT * MAP_ENTRY_SIZE_BYTES; + private static final int MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES; private static final long MAP_MASK = MAP_SLOT_COUNT - 1; - // todo: some fields could probably be converted to locals - private final long map; - + private long currentNamesPtr; + private final long namesHi; private long cursorA; private long endA; private long cursorB; @@ -240,7 +241,7 @@ public class CalculateAverage_jerrinot { continue; } byte[] nameArr = new byte[(int) len]; - long baseNameAddr = baseAddress + NAME_OFFSET; + long baseNameAddr = UNSAFE.getLong(baseAddress + NAME_OFFSET); for (int i = 0; i < len; i++) { nameArr[i] = UNSAFE.getByte(baseNameAddr + i); } @@ -270,6 +271,8 @@ public class CalculateAverage_jerrinot { this.endC = endC; this.endD = endD; this.map = UNSAFE.allocateMemory(MAP_SIZE_BYTES); + this.currentNamesPtr = UNSAFE.allocateMemory(MAP_NAMES_BYTES); + this.namesHi = currentNamesPtr + MAP_NAMES_BYTES; int i; for (i = 0; i < MAP_SIZE_BYTES; i += 8) { @@ -278,6 +281,7 @@ public class CalculateAverage_jerrinot { for (i = i - 8; i < MAP_SIZE_BYTES; i++) { UNSAFE.putByte(map + i, (byte) 0); } + UNSAFE.setMemory(currentNamesPtr, MAP_NAMES_BYTES, (byte) 0); } private void doTail() { @@ -293,58 +297,56 @@ public class CalculateAverage_jerrinot { // System.out.println("done D"); } - private void doOne(long cursorA, long endA) { - while (cursorA < endA) { - long startA = cursorA; - long delimiterWordA = UNSAFE.getLong(cursorA); - long hashA = 0; - long maskA = getDelimiterMask(delimiterWordA); - while (maskA == 0) { - hashA ^= delimiterWordA; - cursorA += 8; - delimiterWordA = UNSAFE.getLong(cursorA); - maskA = getDelimiterMask(delimiterWordA); + private void doOne(long cursor, long endA) { + while (cursor < endA) { + long start = cursor; + long currentWord = UNSAFE.getLong(cursor); + long mask = getDelimiterMask(currentWord); + long maskedFirstWord = currentWord & ((mask - 1) ^ mask) >>> 8; + long hash = hash(maskedFirstWord); + while (mask == 0) { + cursor += 8; + currentWord = UNSAFE.getLong(cursor); + mask = getDelimiterMask(currentWord); } - final int delimiterByteA = Long.numberOfTrailingZeros(maskA); - final long semicolonA = cursorA + (delimiterByteA >> 3); - final long maskedWordA = delimiterWordA & ((maskA - 1) ^ maskA) >>> 8; - hashA ^= maskedWordA; - int intHashA = (int) (hashA ^ (hashA >> 32)); - intHashA = intHashA ^ (intHashA >> 17); - - long baseEntryPtrA = getOrCreateEntryBaseOffset(semicolonA, startA, intHashA, maskedWordA); - long temperatureWordA = UNSAFE.getLong(semicolonA + 1); - cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA, temperatureWordA); + final int delimiterByte = Long.numberOfTrailingZeros(mask); + final long semicolon = cursor + (delimiterByte >> 3); + final long maskedWord = currentWord & ((mask - 1) ^ mask) >>> 8; + long baseEntryPtr = getOrCreateEntryBaseOffset(semicolon, start, (int) hash, maskedWord); + long temperatureWord = UNSAFE.getLong(semicolon + 1); + cursor = parseAndStoreTemperature(semicolon + 1, baseEntryPtr, temperatureWord); } } + private static long hash(long word1) { + // credit: mtopolnik + long seed = 0x51_7c_c1_b7_27_22_0a_95L; + int rotDist = 17; + + long hash = word1; + hash *= seed; + hash = Long.rotateLeft(hash, rotDist); + return hash; + } + @Override public void run() { while (cursorA < endA && cursorB < endB && cursorC < endC && cursorD < endD) { - // todo: experiment with different inter-leaving long startA = cursorA; long startB = cursorB; long startC = cursorC; long startD = cursorD; long currentWordA = UNSAFE.getLong(startA); - // long delimiterWordA2 = UNSAFE.getLong(startA + 8); long currentWordB = UNSAFE.getLong(startB); - // long delimiterWordB2 = UNSAFE.getLong(startB + 8); long currentWordC = UNSAFE.getLong(startC); - // long delimiterWordCa = UNSAFE.getLong(startC + 8); long currentWordD = UNSAFE.getLong(startD); - // long delimiterWordD2 = UNSAFE.getLong(startD + 8); - long hashA = 0; - long hashB = 0; - long hashC = 0; - long hashD = 0; - - // credits for the hashing idea: royvanrijn + // credits for the hashing idea: mtopolnik long maskA = getDelimiterMask(currentWordA); + long maskedFirstWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8; + long hashA = hash(maskedFirstWordA); while (maskA == 0) { - hashA ^= currentWordA; cursorA += 8; currentWordA = UNSAFE.getLong(cursorA); maskA = getDelimiterMask(currentWordA); @@ -353,13 +355,11 @@ public class CalculateAverage_jerrinot { final long semicolonA = cursorA + (delimiterByteA >> 3); long temperatureWordA = UNSAFE.getLong(semicolonA + 1); final long maskedWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8; - hashA ^= maskedWordA; - int intHashA = (int) (hashA ^ (hashA >> 32)); - intHashA = intHashA ^ (intHashA >> 17); long maskB = getDelimiterMask(currentWordB); + long maskedFirstWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8; + long hashB = hash(maskedFirstWordB); while (maskB == 0) { - hashB ^= currentWordB; cursorB += 8; currentWordB = UNSAFE.getLong(cursorB); maskB = getDelimiterMask(currentWordB); @@ -368,13 +368,11 @@ public class CalculateAverage_jerrinot { final long semicolonB = cursorB + (delimiterByteB >> 3); long temperatureWordB = UNSAFE.getLong(semicolonB + 1); final long maskedWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8; - hashB ^= maskedWordB; - int intHashB = (int) (hashB ^ (hashB >> 32)); - intHashB = intHashB ^ (intHashB >> 17); long maskC = getDelimiterMask(currentWordC); + long maskedFirstWordC = currentWordC & ((maskC - 1) ^ maskC) >>> 8; + long hashC = hash(maskedFirstWordC); while (maskC == 0) { - hashC ^= currentWordC; cursorC += 8; currentWordC = UNSAFE.getLong(cursorC); maskC = getDelimiterMask(currentWordC); @@ -383,13 +381,11 @@ public class CalculateAverage_jerrinot { final long semicolonC = cursorC + (delimiterByteC >> 3); long temperatureWordC = UNSAFE.getLong(semicolonC + 1); final long maskedWordC = currentWordC & ((maskC - 1) ^ maskC) >>> 8; - hashC ^= maskedWordC; - int intHashC = (int) (hashC ^ (hashC >> 32)); - intHashC = intHashC ^ (intHashC >> 17); long maskD = getDelimiterMask(currentWordD); + long maskedFirstWordD = currentWordD & ((maskD - 1) ^ maskD) >>> 8; + long hashD = hash(maskedFirstWordD); while (maskD == 0) { - hashD ^= currentWordD; cursorD += 8; currentWordD = UNSAFE.getLong(cursorD); maskD = getDelimiterMask(currentWordD); @@ -398,14 +394,11 @@ public class CalculateAverage_jerrinot { final long semicolonD = cursorD + (delimiterByteD >> 3); long temperatureWordD = UNSAFE.getLong(semicolonD + 1); final long maskedWordD = currentWordD & ((maskD - 1) ^ maskD) >>> 8; - hashD ^= maskedWordD; - int intHashD = (int) (hashD ^ (hashD >> 32)); - intHashD = intHashD ^ (intHashD >> 17); - long baseEntryPtrA = getOrCreateEntryBaseOffset(semicolonA, startA, intHashA, maskedWordA); - long baseEntryPtrB = getOrCreateEntryBaseOffset(semicolonB, startB, intHashB, maskedWordB); - long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, intHashC, maskedWordC); - long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, intHashD, maskedWordD); + long baseEntryPtrA = getOrCreateEntryBaseOffset(semicolonA, startA, (int) hashA, maskedWordA); + long baseEntryPtrB = getOrCreateEntryBaseOffset(semicolonB, startB, (int) hashB, maskedWordB); + long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, (int) hashC, maskedWordC); + long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, (int) hashD, maskedWordD); cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA, temperatureWordA); cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB, temperatureWordB); @@ -415,52 +408,42 @@ public class CalculateAverage_jerrinot { doTail(); } - private long getOrCreateEntryBaseOffset(long semicolonA, long startA, int intHashA, long maskedWordA) { - // hashSet.add(intHashA); - long lenLong = semicolonA - startA; + private long getOrCreateEntryBaseOffset(long semicolonPtr, long startPtr, int hash, long maskedWord) { + long lenLong = semicolonPtr - startPtr; int lenA = (int) lenLong; - // assert lenA != 0; - // byte[] nameArr = new byte[lenA]; - // for (int i = 0; i < lenA; i++) { - // nameArr[i] = UNSAFE.getByte(startA + i); - // } - // String nameStr = new String(nameArr); - // Integer oldHash = nameToHash.put(nameStr, intHashA); - // assert oldHash == null || oldHash == intHashA : "name: " + nameStr + ", old hash = " + oldHash + ", new hash = " + intHashA; - - long mapIndexA = intHashA & MAP_MASK; - // long clusterLen = 0; + long mapIndexA = hash & MAP_MASK; for (;;) { long basePtr = mapIndexA * MAP_ENTRY_SIZE_BYTES + map; long lenPtr = basePtr + LEN_OFFSET; + long namePtr = basePtr + NAME_OFFSET; int len = UNSAFE.getInt(lenPtr); if (len == lenA) { - if (nameMatch(startA, maskedWordA, basePtr, lenLong)) { - // if (clusterLen > maxClusterLen) { - // maxClusterLen = clusterLen; - // System.out.println("max cluster len: " + clusterLen); - // } + namePtr = UNSAFE.getLong(basePtr + NAME_OFFSET); + if (nameMatch(startPtr, maskedWord, namePtr, lenLong)) { return basePtr; } } else if (len == 0) { // todo: uncommon branch maybe? // empty slot - UNSAFE.copyMemory(semicolonA - lenA, basePtr + NAME_OFFSET, lenA); + UNSAFE.putLong(namePtr, currentNamesPtr); UNSAFE.putInt(lenPtr, lenA); // todo: this could be a single putLong() UNSAFE.putInt(basePtr + MAX_OFFSET, Integer.MIN_VALUE); UNSAFE.putInt(basePtr + MIN_OFFSET, Integer.MAX_VALUE); + UNSAFE.copyMemory(startPtr, currentNamesPtr, lenA); + long consume = (lenLong & ~7L) + 8; + currentNamesPtr += consume; + assert currentNamesPtr <= namesHi; return basePtr; } mapIndexA = ++mapIndexA & MAP_MASK; - // clusterLen++; } } - private static boolean nameMatch(long startA, long maskedWordA, long basePtr, long len) { - long namePtr = basePtr + NAME_OFFSET; + private static boolean nameMatch(long startA, long maskedWordA, long namePtr, long len) { + // long namePtr = basePtr + NAME_OFFSET; long fullLen = len & ~7L; long offset;