diff --git a/calculate_average_ebarlas.sh b/calculate_average_ebarlas.sh index 9c5c215..422867d 100755 --- a/calculate_average_ebarlas.sh +++ b/calculate_average_ebarlas.sh @@ -16,4 +16,4 @@ # JAVA_OPTS="" -java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ebarlas measurements.txt 8 +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ebarlas diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ebarlas.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ebarlas.java index 63ff69f..b2a89d0 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_ebarlas.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ebarlas.java @@ -15,6 +15,8 @@ */ package dev.morling.onebrc; +import sun.misc.Unsafe; + import java.io.IOException; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; @@ -23,7 +25,6 @@ import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; -import java.util.Arrays; import java.util.List; import java.util.TreeMap; @@ -33,13 +34,22 @@ public class CalculateAverage_ebarlas { private static final int HASH_FACTOR = 433; private static final int HASH_TBL_SIZE = 16_383; // range of allowed hash values, inclusive - public static void main(String[] args) throws IOException, InterruptedException { - if (args.length != 2) { - System.out.println("Usage: java CalculateAverage "); - System.exit(1); + private static final Unsafe UNSAFE = makeUnsafe(); + + private static Unsafe makeUnsafe() { + try { + var f = Unsafe.class.getDeclaredField("theUnsafe"); + f.setAccessible(true); + return (Unsafe) f.get(null); } - var path = Paths.get(args[0]); - var numPartitions = Integer.parseInt(args[1]); + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + public static void main(String[] args) throws IOException, InterruptedException { + var path = Paths.get("measurements.txt"); + var numPartitions = Math.max(8, Runtime.getRuntime().availableProcessors()); var channel = FileChannel.open(path, StandardOpenOption.READ); var partitionSize = channel.size() / numPartitions; var partitions = new Partition[numPartitions]; @@ -75,13 +85,31 @@ public class CalculateAverage_ebarlas { var result = new TreeMap(); for (var st : stats) { if (st != null) { - var key = new String(st.key, StandardCharsets.UTF_8); + var key = new String(convert(st.keyAddr, st.keyLen, st.lastBytes), StandardCharsets.UTF_8); result.put(key, format(st)); } } System.out.println(result); } + private static byte[] convert(long keyAddr, int keyLen, int keyLastBytes) { + var len = keyLastBytes == 4 + ? keyLen * 4 // fully packed + : (keyLen - 1) * 4 + keyLastBytes; // last int partially packed + var bytes = new byte[len]; + var idx = 0; + for (long i = 0; i < keyLen; i++) { + var offset = i << 2; + var n = UNSAFE.getInt(keyAddr + offset); + var bound = i == keyLen - 1 ? keyLastBytes : 4; + for (int j = 0; j < bound; j++) { + bytes[idx++] = (byte) (n & 0xFF); + n >>>= 8; + } + } + return bytes; + } + private static String format(Stats st) { // adheres to expected output format return round(st.min / 10.0) + "/" + round((st.sum / 10.0) / st.count) + "/" + round(st.max / 10.0); } @@ -96,7 +124,7 @@ public class CalculateAverage_ebarlas { var current = partitions.get(i).stats; for (int j = 0; j < current.length; j++) { if (current[j] != null) { - var t = findInTable(target, current[j].hash, current[j].key, current[j].key.length); + var t = findInTable(target, current[j].hash, current[j].keyAddr, current[j].keyLen, current[j].lastBytes); t.min = Math.min(t.min, current[j].min); t.max = Math.max(t.max, current[j].max); t.sum += current[j].sum; @@ -112,7 +140,7 @@ public class CalculateAverage_ebarlas { var pNext = partitions.get(i); var pPrev = partitions.get(i - 1); var merged = mergeFooterAndHeader(pPrev.footer, pNext.header); - if (merged != null) { + if (merged != null && merged.length != 0) { if (merged[merged.length - 1] == '\n') { // fold into prev partition doProcessBuffer(ByteBuffer.wrap(merged).order(ByteOrder.LITTLE_ENDIAN), true, pPrev.stats); } @@ -148,80 +176,70 @@ public class CalculateAverage_ebarlas { } private static int reallyDoProcessBuffer(ByteBuffer buffer, Stats[] stats) { - var keyBuf = new byte[MAX_KEY_SIZE]; // buffer for key + long keyBaseAddr = UNSAFE.allocateMemory(MAX_KEY_SIZE); int keyStart = 0; // start of key in buffer used for footer calc try { // abort with exception to allow optimistic line processing while (true) { // one line per iteration keyStart = buffer.position(); // preserve line start - int n = buffer.getInt(); // first four bytes of key - byte b1 = (byte) (n & 0xFF); - byte b2 = (byte) ((n >> 8) & 0xFF); - byte b3 = (byte) ((n >> 16) & 0xFF); - byte b = (byte) ((n >> 24) & 0xFF); - int keyPos; - int keyHash = keyBuf[0] = b1; - if (b2 != ';' && b3 != ';') { // true for keys of length 3 or more - keyBuf[1] = b2; - keyBuf[2] = b3; - keyHash = HASH_FACTOR * (HASH_FACTOR * keyHash + b2) + b3; - keyPos = 3; - while (b != ';') { - keyHash = HASH_FACTOR * keyHash + b; - keyBuf[keyPos++] = b; - b = buffer.get(); + int keyHash = 0; // key hash code + long keyAddr = keyBaseAddr; // address for next int + int keyArrLen = 0; // number of key 4-byte ints + int keyLastBytes; // occupancy in last byte (1, 2, 3, or 4) + int val; // temperature value + while (true) { + int n = buffer.getInt(); + byte b0 = (byte) (n & 0xFF); + byte b1 = (byte) ((n >> 8) & 0xFF); + byte b2 = (byte) ((n >> 16) & 0xFF); + byte b3 = (byte) ((n >> 24) & 0xFF); + if (b0 == ';') { // ...;1.1 + val = getVal(buffer, b1, b2, b3, buffer.get()); + keyLastBytes = 4; + break; } - } - else { // slow path, rewind and consume byte-by-byte - buffer.position(keyStart + 1); - keyPos = 1; - while ((b = buffer.get()) != ';') { - keyHash = HASH_FACTOR * keyHash + b; - keyBuf[keyPos++] = b; + else if (b1 == ';') { // ...a;1.1 + val = getVal(buffer, b2, b3, buffer.get(), buffer.get()); + UNSAFE.putInt(keyAddr, b0); + keyLastBytes = 1; + keyArrLen++; + keyHash = HASH_FACTOR * keyHash + b0; + break; + } + else if (b2 == ';') { // ...ab;1.1 + val = getVal(buffer, b3, buffer.get(), buffer.get(), buffer.get()); + UNSAFE.putInt(keyAddr, n & 0x0000FFFF); + keyLastBytes = 2; + keyArrLen++; + keyHash = HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1; + break; + } + else if (b3 == ';') { // ...abc;1.1 + UNSAFE.putInt(keyAddr, n & 0x00FFFFFF); + keyLastBytes = 3; + keyArrLen++; + keyHash = HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1) + b2; + n = buffer.getInt(); + b0 = (byte) (n & 0xFF); + b1 = (byte) ((n >> 8) & 0xFF); + b2 = (byte) ((n >> 16) & 0xFF); + b3 = (byte) ((n >> 24) & 0xFF); + val = getVal(buffer, b0, b1, b2, b3); + break; + } + else { + UNSAFE.putInt(keyAddr, n); + keyArrLen++; + keyAddr += 4; + keyHash = HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1) + b2) + b3; } } var idx = keyHash & HASH_TBL_SIZE; var st = stats[idx]; if (st == null) { // nothing in table, eagerly claim spot - st = stats[idx] = newStats(keyBuf, keyPos, keyHash); + st = stats[idx] = newStats(keyBaseAddr, keyArrLen, keyLastBytes, keyHash); } - else if (!Arrays.equals(st.key, 0, st.key.length, keyBuf, 0, keyPos)) { - st = findInTable(stats, keyHash, keyBuf, keyPos); - } - var value = buffer.getInt(); - b = (byte) (value & 0xFF); // digit or dash - int val; - if (b == '-') { // dash branch - val = ((byte) ((value >> 8) & 0xFF)) - '0'; // digit after dash - b = (byte) ((value >> 16) & 0xFF); // second digit or decimal - if (b != '.') { // second digit - val = val * 10 + (b - '0'); // calc second digit - // skip decimal (at >> 24) - b = buffer.get(); // digit after decimal - val = val * 10 + (b - '0'); // calc digit after decimal - } - else { // decimal branch - // skip decimal (at >> 16) - b = (byte) ((value >> 24) & 0xFF); // digit after decimal - val = val * 10 + (b - '0'); // calc digit after decimal - } - buffer.get(); // newline - val = -val; - } - else { // first digit branch - val = b - '0'; // calc first digit - b = (byte) ((value >> 8) & 0xFF); // second digit or decimal - if (b != '.') { // second digit branch - val = val * 10 + (b - '0'); // calc second digit - // skip decimal (at >> 16) - b = (byte) ((value >> 24) & 0xFF); // digit after decimal - val = val * 10 + (b - '0'); // calc digit after decimal - buffer.get(); // newline - } - else { // decimal branch - b = (byte) ((value >> 16) & 0xFF); // digit after decimal - val = val * 10 + (b - '0'); // calc digit after decimal - // skip newline (at >> 24) - } + else if (!equals(st.keyAddr, st.keyLen, keyBaseAddr, keyArrLen)) { + st = findInTable(stats, keyHash, keyBaseAddr, keyArrLen, keyLastBytes); } st.min = Math.min(st.min, val); st.max = Math.max(st.max, val); @@ -235,23 +253,60 @@ public class CalculateAverage_ebarlas { return keyStart; } - private static Stats findInTable(Stats[] stats, int hash, byte[] key, int len) { // open-addressing scan + private static boolean equals(long key1, int len1, long key2, int len2) { + if (len1 != len2) { + return false; + } + for (long i = 0; i < len1; i++) { + var offset = i << 2; + if (UNSAFE.getInt(key1 + offset) != UNSAFE.getInt(key2 + offset)) { + return false; + } + } + return true; + } + + private static int getVal(ByteBuffer buffer, byte b0, byte b1, byte b2, byte b3) { + if (b0 == '-') { + if (b2 != '.') { // 6 bytes: -dd.dn + var b = buffer.get(); + buffer.get(); // newline + return -(((b1 - '0') * 10 + (b2 - '0')) * 10 + (b - '0')); + } + else { // 5 bytes: -d.dn + buffer.get(); // newline + return -((b1 - '0') * 10 + (b3 - '0')); + } + } + else { + if (b1 != '.') { // 5 bytes: dd.dn + buffer.get(); // newline + return ((b0 - '0') * 10 + (b1 - '0')) * 10 + (b3 - '0'); + } + else { // 4 bytes: d.dn + return (b0 - '0') * 10 + (b2 - '0'); + } + } + } + + private static Stats findInTable(Stats[] stats, int hash, long keyAddr, int keyLen, int keyLastBytes) { // open-addressing scan var idx = hash & HASH_TBL_SIZE; var st = stats[idx]; - while (st != null && !Arrays.equals(st.key, 0, st.key.length, key, 0, len)) { + while (st != null && !equals(st.keyAddr, st.keyLen, keyAddr, keyLen)) { idx = (idx + 1) % (HASH_TBL_SIZE + 1); st = stats[idx]; } if (st != null) { return st; } - return stats[idx] = newStats(key, len, hash); + return stats[idx] = newStats(keyAddr, keyLen, keyLastBytes, hash); } - private static Stats newStats(byte[] buffer, int len, int hash) { - var k = new byte[len]; - System.arraycopy(buffer, 0, k, 0, len); - return new Stats(k, hash); + private static Stats newStats(long keyAddr, int keyLen, int keyLastBytes, int hash) { + var bytes = keyLen << 2; + long k = UNSAFE.allocateMemory(bytes); + UNSAFE.copyMemory(keyAddr, k, bytes); + return new Stats(k, keyLen, keyLastBytes, hash); } private static byte[] readFooter(ByteBuffer buffer, int lineStart) { // read from line start to current pos (end-of-input) @@ -281,15 +336,19 @@ public class CalculateAverage_ebarlas { } private static class Stats { // min, max, and sum values are modeled with integral types that represent tenths of a unit - final byte[] key; + final long keyAddr; // address of 4-byte integer array + final int keyLen; // number of 4-byte integers starting at address + final int lastBytes; // number of bytes packed into last key int (1, 2, 3 or 4) final int hash; int min = Integer.MAX_VALUE; int max = Integer.MIN_VALUE; long sum; long count; - Stats(byte[] key, int hash) { - this.key = key; + Stats(long keyAddr, int keyLen, int lastBytes, int hash) { + this.keyAddr = keyAddr; + this.keyLen = keyLen; + this.lastBytes = lastBytes; this.hash = hash; } }