From 27b9232b7d1475f76a3e43cb65e5e4eb84aaa1e3 Mon Sep 17 00:00:00 2001 From: gonix Date: Wed, 17 Jan 2024 19:48:05 +0200 Subject: [PATCH] CalculateAverage_gonix update (#461) Co-authored-by: Giedrius D --- .../onebrc/CalculateAverage_gonix.java | 172 ++++++++---------- 1 file changed, 71 insertions(+), 101 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java b/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java index 8349d00..90f4360 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java @@ -133,78 +133,65 @@ class Aggregator { int start = pos; int hash = 0; + long tail = 0; while (true) { - // This is a bit ugly, but it is faster than reading by byte. + // Seen this trick used in multiple other solutions. + // Nice breakdown here: https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord long tmpLong = buf.getLong(pos); - if ((tmpLong & 0xFF) == ';') { - break; + long match = tmpLong ^ 0x3B3B3B3B_3B3B3B3BL; // 3B == ';' + match = ((match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L)); + if (match == 0) { + hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFFFFFFFF); + pos += 8; + continue; } - if (((tmpLong >>> 8) & 0xFF) == ';') { - hash = (33 * hash) ^ (int) (tmpLong & 0xFF); - pos += 1; - break; - } - if (((tmpLong >>> 16) & 0xFF) == ';') { - hash = (33 * hash) ^ (int) (tmpLong & 0xFFFF); - pos += 2; - break; - } - if (((tmpLong >>> 24) & 0xFF) == ';') { - hash = (33 * hash) ^ (int) (tmpLong & 0xFFFFFF); - pos += 3; - break; - } - if (((tmpLong >>> 32) & 0xFF) == ';') { - hash = (33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF); - pos += 4; - break; - } - if (((tmpLong >>> 40) & 0xFF) == ';') { - hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFF); - pos += 5; - break; - } - if (((tmpLong >>> 48) & 0xFF) == ';') { - hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFFFF); - pos += 6; - break; - } - if (((tmpLong >>> 56) & 0xFF) == ';') { - hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFFFFFF); - pos += 7; - break; - } - hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFFFFFFFF); - pos += 8; + + int tailBits = Long.numberOfTrailingZeros(match >>> 7); + long tailMask = ~(-1L << tailBits); + tail = tmpLong & tailMask; + hash = ((33 * hash) ^ (int) (tail & 0xFFFFFFFF)) + (int) ((tail >>> 33) & 0xFFFFFFFF); + pos += tailBits >> 3; + break; } hash = (33 * hash) ^ (hash >>> 15); - int len = pos - start; - assert (buf.get(pos) == ';') : "Expected ';'"; + int lenInLongs = (pos - start) >> 3; + long tailAndLen = (tail << 8) | (lenInLongs & 0xFF); + // assert (buf.get(pos) == ';') : "Expected ';'"; pos++; int measurement; { + // Seen this trick used in multiple other solutions. + // Looks like the original author is @merykitty. long tmpLong = buf.getLong(pos); - int sign = 1; - if ((tmpLong & 0xFF) == '-') { - sign = -1; - tmpLong >>>= 8; - pos++; - } - int value; - if (((tmpLong >>> 8) & 0xFF) == '.') { - value = (int) (((tmpLong & 0xFF) - '0') * 10 + (((tmpLong >>> 16) & 0xFF) - '0')); - pos += 4; - } - else { - value = (int) (((tmpLong & 0xFF) - '0') * 100 + (((tmpLong >>> 8) & 0xFF) - '0') * 10 + (((tmpLong >>> 24) & 0xFF) - '0')); - pos += 5; - } - measurement = sign * value; - } - assert (buf.get(pos - 1) == '\n') : "Expected '\\n'"; - add(buf, start, len, hash, measurement); + // The 4th binary digit of the ascii of a digit is 1 while + // that of the '.' is 0. This finds the decimal separator + // The value can be 12, 20, 28 + int decimalSepPos = Long.numberOfTrailingZeros(~tmpLong & 0x10101000); + int shift = 28 - decimalSepPos; + // signed is -1 if negative, 0 otherwise + long signed = (~tmpLong << 59) >> 63; + long designMask = ~(signed & 0xFF); + // Align the number to a specific position and transform the ascii code + // to actual digit value in each byte + long digits = ((tmpLong & designMask) << shift) & 0x0F000F0F00L; + + // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit) + // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) = + // 0x000000UU00TTHH00 + + // 0x00UU00TTHH000000 * 10 + + // 0xUU00TTHH00000000 * 100 + // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400 + // This results in our value lies in the bit 32 to 41 of this product + // That was close :) + long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; + measurement = (int) ((absValue ^ signed) - signed); + pos += (decimalSepPos >>> 3) + 3; + } + // assert (buf.get(pos - 1) == '\n') : "Expected '\\n'"; + + add(buf, start, tailAndLen, hash, measurement); } return this; @@ -216,13 +203,13 @@ class Aggregator { .mapToObj(offset -> new Entry(mem, offset)); } - private void add(ByteBuffer buf, int start, int len, int hash, int measurement) { + private void add(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) { int idx = hash & INDEX_MASK; while (true) { if (index[idx] != 0) { int offset = index[idx]; - if (keyEqual(offset, buf, start, len)) { - int pos = offset + (len >> 3) + 2; + if (keyEqual(offset, buf, start, tailAndLen)) { + int pos = offset + (int) (tailAndLen & 0xFF) + 1; mem[pos + FLD_MIN] = Math.min((int) measurement, (int) mem[pos + FLD_MIN]); mem[pos + FLD_MAX] = Math.max((int) measurement, (int) mem[pos + FLD_MAX]); mem[pos + FLD_SUM] += measurement; @@ -231,39 +218,27 @@ class Aggregator { } } else { - index[idx] = create(buf, start, len, hash, measurement); + index[idx] = create(buf, start, tailAndLen, hash, measurement); return; } idx = (idx + 1) & INDEX_MASK; } } - private int create(ByteBuffer buf, int start, int len, int hash, int measurement) { + private int create(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) { int offset = memUsed; - mem[offset] = len; + mem[offset] = tailAndLen; int memPos = offset + 1; - int memEndEarly = memPos + (len >> 3); + int memEnd = memPos + (int) (tailAndLen & 0xFF); int bufPos = start; - int bufEnd = start + len; - while (memPos < memEndEarly) { + while (memPos < memEnd) { mem[memPos] = buf.getLong(bufPos); memPos += 1; bufPos += 8; } - if (bufPos < bufEnd) { - int shift = (8 - (len & 7)) << 3; // (8 - (len % 8)) * 8 - long tmpLong = buf.getLong(bufPos) << shift >>> shift; - mem[memPos] = tmpLong; - } - else { - // "consume" extra long - makes math a bit simpler to calculate - // fields offset for update. - mem[memPos] = 0; - } - memPos += 1; mem[memPos + FLD_MIN] = measurement; mem[memPos + FLD_MAX] = measurement; mem[memPos + FLD_SUM] = measurement; @@ -273,28 +248,21 @@ class Aggregator { return offset; } - private boolean keyEqual(int offset, ByteBuffer buf, int start, int len) { - if (len != mem[offset]) { + private boolean keyEqual(int offset, ByteBuffer buf, int start, long tailAndLen) { + + if (mem[offset] != tailAndLen) { return false; } int memPos = offset + 1; - int memEndEarly = memPos + (len >> 3); + int memEnd = memPos + (int) (tailAndLen & 0xFF); int bufPos = start; - int bufEnd = start + len; - while (memPos < memEndEarly) { + while (memPos < memEnd) { if (mem[memPos] != buf.getLong(bufPos)) { return false; } memPos += 1; bufPos += 8; } - if (bufPos < bufEnd) { - int shift = (8 - (len & 7)) << 3; // (8 - (len % 8)) * 8 - long tmpLong = buf.getLong(bufPos) << shift >>> shift; - if (mem[memPos] != tmpLong) { - return false; - } - } return true; } @@ -311,19 +279,22 @@ class Aggregator { public String getKey() { if (key == null) { int pos = this.offset; - int keyLen = (int) mem[pos++]; - var tmpBuf = ByteBuffer.allocate(keyLen + 8).order(ByteOrder.nativeOrder()); - for (int i = 0; i < keyLen; i += 8) { + long tailAndLen = mem[pos++]; + int keyLen = (int) (tailAndLen & 0xFF); + var tmpBuf = ByteBuffer.allocate((keyLen << 3) + 8).order(ByteOrder.nativeOrder()); + for (int i = 0; i < keyLen; i++) { tmpBuf.putLong(mem[pos++]); } - key = new String(tmpBuf.array(), 0, keyLen, StandardCharsets.UTF_8); + long tail = tailAndLen >>> 8; + tmpBuf.putLong(tail); + int keyLenBytes = (keyLen << 3) + 8 - (Long.numberOfLeadingZeros(tail) >> 3); + key = new String(tmpBuf.array(), 0, keyLenBytes, StandardCharsets.UTF_8); } return key; } public Entry add(Entry other) { - int keyLen = (int) mem[offset]; - int fldOffset = (keyLen >> 3) + 2; + int fldOffset = (int) (mem[offset] & 0xFF) + 1; int pos = offset + fldOffset; int otherPos = other.offset + fldOffset; long[] otherMem = other.mem; @@ -340,8 +311,7 @@ class Aggregator { @Override public String toString() { - int keyLen = (int) mem[offset]; - int pos = offset + (keyLen >> 3) + 2; + int pos = offset + (int) (mem[offset] & 0xFF) + 1; return round(mem[pos + FLD_MIN]) + "/" + round(((double) mem[pos + FLD_SUM]) / mem[pos + FLD_COUNT]) + "/" + round(mem[pos + FLD_MAX]);