diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java index 6800456..1cd70e4 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java @@ -62,7 +62,10 @@ import sun.misc.Unsafe; * Unrolling scan-loop: 1200 ms (seems to help, perhaps even more on target machine) * Adding more readable reader: 1300 ms (scores got worse on target machine anyway) * - * I've ditched my M2 for an older x86-64 MacBook, this allows me to run `perf` and I'm trying to get lower numbers by trail and error. + * Using old x86 MacBook and perf: 3500 ms (different scoring) + * Decided to rewrite loop for 16 b: 3050 ms + * + * I have some instructions that could be removed, but faster with... * * Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai and many others for ideas. * @@ -80,7 +83,7 @@ public class CalculateAverage_royvanrijn { /** * Flyweight entry in a byte[], max 128 bytes. - * + *
* long: sum * int: min * int: max @@ -122,10 +125,12 @@ public class CalculateAverage_royvanrijn { } public static void main(String[] args) throws Exception { + if (args.length == 0 || !("--worker".equals(args[0]))) { spawnWorker(); return; } + // Calculate input segments. final FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ); final long fileSize = fileChannel.size(); @@ -184,13 +189,15 @@ public class CalculateAverage_royvanrijn { System.out.close(); // close the stream to stop } - private static byte[] fillEntry(final byte[] entry, final long fromAddress, final int length, final int temp) { + private static byte[] fillEntry(final byte[] entry, final long fromAddress, final int entryLength, final int temp, final long readBuffer1, final long readBuffer2) { UNSAFE.putLong(entry, ENTRY_SUM, temp); UNSAFE.putInt(entry, ENTRY_MIN, temp); UNSAFE.putInt(entry, ENTRY_MAX, temp); UNSAFE.putInt(entry, ENTRY_COUNT, 1); - UNSAFE.putByte(entry, ENTRY_LENGTH, (byte) length); - UNSAFE.copyMemory(null, fromAddress, entry, ENTRY_NAME, length); + UNSAFE.putByte(entry, ENTRY_LENGTH, (byte) entryLength); + UNSAFE.copyMemory(null, fromAddress, entry, ENTRY_NAME, entryLength - 16); + UNSAFE.putLong(entry, ENTRY_NAME + entryLength - 16, readBuffer1); + UNSAFE.putLong(entry, ENTRY_NAME + entryLength - 8, readBuffer2); return entry; } @@ -219,16 +226,16 @@ public class CalculateAverage_royvanrijn { int count = UNSAFE.getInt(merge, ENTRY_COUNT); sum += UNSAFE.getLong(entry, ENTRY_SUM); - int entryMin = UNSAFE.getInt(entry, ENTRY_MIN); - int entryMax = UNSAFE.getInt(entry, ENTRY_MAX); count += UNSAFE.getInt(entry, ENTRY_COUNT); + int entryMin = UNSAFE.getInt(entry, ENTRY_MIN); + int entryMax = UNSAFE.getInt(entry, ENTRY_MAX); entryMin = Math.min(entryMin, mergeMin); entryMax = Math.max(entryMax, mergeMax); - - UNSAFE.putLong(entry, ENTRY_SUM, sum); UNSAFE.putInt(entry, ENTRY_MIN, entryMin); UNSAFE.putInt(entry, ENTRY_MAX, entryMax); + + UNSAFE.putLong(entry, ENTRY_SUM, sum); UNSAFE.putInt(entry, ENTRY_COUNT, count); return entry; } @@ -241,16 +248,16 @@ public class CalculateAverage_royvanrijn { UNSAFE.copyMemory(entry, ENTRY_NAME, name, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); // Create a new String with the existing byte[]: - return new String(name, StandardCharsets.UTF_8); + return new String(name, StandardCharsets.UTF_8).trim(); } private static String entryValuesToString(final byte[] entry) { - return round(UNSAFE.getInt(entry, ENTRY_MIN)) + return (round(UNSAFE.getInt(entry, ENTRY_MIN)) + "/" + round((1.0 * UNSAFE.getLong(entry, ENTRY_SUM)) / UNSAFE.getInt(entry, ENTRY_COUNT)) + "/" + - round(UNSAFE.getInt(entry, ENTRY_MAX)); + round(UNSAFE.getInt(entry, ENTRY_MAX))); } // Print a piece of memory: @@ -280,13 +287,12 @@ public class CalculateAverage_royvanrijn { private static final class Reader { private long ptr; - private long delimiterMask; - private long lastRead; - private long lastReadMinOne; + private long readBuffer1; + private long readBuffer2; private long hash; private long entryStart; - private long entryDelimiter; + private int entryLength; // in bytes rounded to nearest 16 private final long endAddress; @@ -309,6 +315,7 @@ public class CalculateAverage_royvanrijn { private void processStart() { hash = 0; entryStart = ptr; + entryLength = 0; } private boolean hasNext() { @@ -317,64 +324,77 @@ public class CalculateAverage_royvanrijn { private static final long DELIMITER_MASK = 0x3B3B3B3B3B3B3B3BL; - private boolean readFirst() { - lastRead = UNSAFE.getLong(ptr); - - final long match = lastRead ^ DELIMITER_MASK; - delimiterMask = (match - 0x0101010101010101L) & (~match & 0x8080808080808080L); - - return delimiterMask == 0; - } - private boolean readNext() { - lastReadMinOne = lastRead; - return readFirst(); - } - private void processName() { - hash ^= lastRead; - ptr += 8; + readBuffer1 = UNSAFE.getLong(ptr); + readBuffer2 = UNSAFE.getLong(ptr + 8); + + entryLength += 16; + + // Find delimiter and create mask for long1 + long comparisonResult1 = (readBuffer1 ^ DELIMITER_MASK); + long highBitMask1 = (comparisonResult1 - 0x0101010101010101L) & (~comparisonResult1 & 0x8080808080808080L); + + boolean noContent1 = highBitMask1 == 0; + long mask1 = noContent1 ? 0 : ~((highBitMask1 >>> 7) - 1); + int position1 = noContent1 ? -1 : Long.numberOfTrailingZeros(highBitMask1) >> 3; + + readBuffer1 &= ~mask1; + hash ^= readBuffer1; + + if (position1 != -1) { + hash ^= hash >> 32; + readBuffer2 = 0; + ptr += position1 + 1; + return false; + } + + // Repeat for long2 + long comparisonResult2 = (readBuffer2 ^ DELIMITER_MASK); + long highBitMask2 = (comparisonResult2 - 0x0101010101010101L) & (~comparisonResult2 & 0x8080808080808080L); + boolean noContent2 = highBitMask2 == 0; + long mask2 = noContent2 ? -1 : ((highBitMask2 >>> 7) - 1); + int position2 = noContent2 ? -1 : Long.numberOfTrailingZeros(highBitMask2) >> 3; + + mask2 = ~mask2; // also not necessary, but faster with? + // Apply masks + readBuffer2 &= ~mask2; + hash ^= readBuffer2; + + int delimiter = position2 == -1 ? -1 : position2 + 8; // not nnecessary, but faster? + + hash ^= hash >> 32; + + if (delimiter == -1) { + ptr += 16; + return true; + } + ptr += delimiter + 1; + return false; } private int processEndAndGetTemperature() { - processFinalBytes(); - finalizeHash(); - finalizeDelimiter(); - return readTemperature(); } - private void processFinalBytes() { - // Shift and read the last bytes: - lastRead &= ((delimiterMask >>> 7) - 1); - } - private void finalizeHash() { - // Finalize hash: - hash ^= lastRead; - hash ^= hash >> 32; hash ^= hash >> 17; // extra entropy } - private void finalizeDelimiter() { - // Found delimiter: - entryDelimiter = ptr + (Long.numberOfTrailingZeros(delimiterMask) >> 3); - } - private static final long DOT_BITS = 0x10101000; private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); // Awesome idea of merykitty: private int readTemperature() { // This is the number part: X.X, -X.X, XX.x or -XX.X - long numberBytes = UNSAFE.getLong(entryDelimiter + 1); + long numberBytes = UNSAFE.getLong(ptr); long invNumberBytes = ~numberBytes; int dotPosition = Long.numberOfTrailingZeros(invNumberBytes & DOT_BITS); // Update the pointer here, bit awkward, but we have all the data - ptr = entryDelimiter + (dotPosition >> 3) + 4; + ptr += (dotPosition >> 3) + 3; int min28 = (28 - dotPosition); // Calculates the sign @@ -388,57 +408,32 @@ public class CalculateAverage_royvanrijn { return (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick } - private boolean matchesEntryFull(final byte[] entry) { - int longs = (int) (entryDelimiter - entryStart) >> 3; + private boolean matches(final byte[] entry) { int step = 0; - for (int i = 0; i < longs - 2; i++) { - if (UNSAFE.getLong(entryStart + step) != UNSAFE.getLong(entry, ENTRY_NAME + step)) { + for (; step < entryLength - 16;) { + if (compare(null, entryStart + step, entry, ENTRY_NAME + step)) { + return false; + } + step += 8; + if (compare(null, entryStart + step, entry, ENTRY_NAME + step)) { return false; } step += 8; } - if (lastReadMinOne != UNSAFE.getLong(entry, (ENTRY_NAME_8) + step)) { + if (compare(readBuffer1, entry, ENTRY_NAME + step)) { return false; } - if (lastRead != UNSAFE.getLong(entry, (ENTRY_NAME_16) + step)) { - return false; - } - return true; - - } - - private boolean matchesEntryMedium(final byte[] entry) { - if (UNSAFE.getLong(entryStart) != UNSAFE.getLong(entry, ENTRY_NAME)) { - return false; - } - if (lastReadMinOne != UNSAFE.getLong(entry, ENTRY_NAME_8)) { - return false; - } - if (lastRead != UNSAFE.getLong(entry, ENTRY_NAME_16)) { + step += 8; + if (compare(readBuffer2, entry, ENTRY_NAME + step)) { return false; } return true; } - private boolean matchesEntryShort(final byte[] entry) { - if (lastReadMinOne != UNSAFE.getLong(entry, ENTRY_NAME)) { - return false; - } - if (lastRead != UNSAFE.getLong(entry, ENTRY_NAME_8)) { - return false; - } - return true; + private boolean matches16(final byte[] entry) { + return !compare(readBuffer1, entry, ENTRY_NAME) && + !compare(readBuffer2, entry, ENTRY_NAME + 8); } - - private boolean matchesEnding(final byte[] entry) { - return lastRead == UNSAFE.getLong(entry, ENTRY_NAME); - } - - private int length() { - return (int) (entryDelimiter - entryStart); - - } - } private static byte[][] processMemoryArea(final long startAddress, final long endAddress, boolean isFileStart) { @@ -456,8 +451,8 @@ public class CalculateAverage_royvanrijn { reader.processStart(); - if (!reader.readFirst()) { - // Found delimiter in first 8 bytes: + if (!reader.readNext()) { + // First 16 bytes: int temperature = reader.processEndAndGetTemperature(); @@ -466,13 +461,12 @@ public class CalculateAverage_royvanrijn { while (true) { entry = table[index]; if (entry == null) { - int length = reader.length(); byte[] entryBytes = (entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++] - : new byte[ENTRY_BASESIZE_WHITESPACE + length]; - table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature); + : new byte[ENTRY_BASESIZE_WHITESPACE + 16]; // with enough room + table[index] = fillEntry(entryBytes, reader.entryStart, 16, temperature, reader.readBuffer1, reader.readBuffer2); break; } - else if (reader.matchesEnding(entry)) { + else if (reader.matches16(entry)) { updateEntry(entry, temperature); break; } @@ -481,104 +475,45 @@ public class CalculateAverage_royvanrijn { index = (index + 1) & TABLE_MASK; } } + continue; } - else { - reader.processName(); + while (reader.readNext()) + ; - if (!reader.readNext()) { - // Found delimiter in 8-16 bytes: + int temperature = reader.processEndAndGetTemperature(); - int temperature = reader.processEndAndGetTemperature(); - - // Find or insert the entry: - int index = (int) (reader.hash & TABLE_MASK); - while (true) { - entry = table[index]; - if (entry == null) { - int length = reader.length(); - byte[] entryBytes = (entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++] - : new byte[ENTRY_BASESIZE_WHITESPACE + length]; - table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature); - break; - } - else if (reader.matchesEntryShort(entry)) { - updateEntry(entry, temperature); - break; - } - else { - // Move to the next index - index = (index + 1) & TABLE_MASK; - } - } + // Find or insert the entry: + int index = (int) (reader.hash & TABLE_MASK); + while (true) { + entry = table[index]; + if (entry == null) { + int length = reader.entryLength; + byte[] entryBytes = (length < PREMADE_MAX_SIZE && entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++] + : new byte[ENTRY_BASESIZE_WHITESPACE + length]; // with enough room + table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature, reader.readBuffer1, reader.readBuffer2); + break; + } + else if (reader.matches(entry)) { + updateEntry(entry, temperature); + break; } else { - reader.processName(); - - if (!reader.readNext()) { - // Found delimiter in 16-24 bytes: - - int temperature = reader.processEndAndGetTemperature(); - - // Find or insert the entry: - int index = (int) (reader.hash & TABLE_MASK); - while (true) { - entry = table[index]; - if (entry == null) { - int length = reader.length(); - byte[] entryBytes = (entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++] - : new byte[ENTRY_BASESIZE_WHITESPACE + length]; - table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature); - break; - } - else if (reader.matchesEntryMedium(entry)) { - updateEntry(entry, temperature); - break; - } - else { - // Move to the next index - index = (index + 1) & TABLE_MASK; - } - } - - } - else { - // Need more than 24 bytes: - - reader.processName(); - while (reader.readNext()) { - reader.processName(); - } - - int temperature = reader.processEndAndGetTemperature(); - - // Find or insert the entry: - int index = (int) (reader.hash & TABLE_MASK); - while (true) { - entry = table[index]; - if (entry == null) { - int length = reader.length(); - byte[] entryBytes = (length < PREMADE_MAX_SIZE && entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++] - : new byte[ENTRY_BASESIZE_WHITESPACE + length]; // with enough room - table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature); - break; - } - else if (reader.matchesEntryFull(entry)) { - updateEntry(entry, temperature); - break; - } - else { - // Move to the next index - index = (index + 1) & TABLE_MASK; - } - } - } + // Move to the next index + index = (index + 1) & TABLE_MASK; } } - } return table; } + private static boolean compare(final Object object1, final long address1, final Object object2, final long address2) { + return UNSAFE.getLong(object1, address1) != UNSAFE.getLong(object2, address2); + } + + private static boolean compare(final long value1, final Object object2, final long address2) { + return value1 != UNSAFE.getLong(object2, address2); + } + /* * `___` ___ ___ _ ___` ` ___ ` _ ` _ ` _` ___ * / ` \| _ \ __| \| \ \ / /_\ | | | | | | __|