diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_netrunnereve.java b/src/main/java/dev/morling/onebrc/CalculateAverage_netrunnereve.java index e323a32..13919cf 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_netrunnereve.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_netrunnereve.java @@ -21,15 +21,18 @@ import java.io.RandomAccessFile; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; -import java.lang.Math; import java.util.Map; import java.util.TreeMap; +import java.util.concurrent.CountDownLatch; +import java.lang.Math; public class CalculateAverage_netrunnereve { private static final String FILE = "./measurements.txt"; private static final int NUM_THREADS = 8; // test machine private static final int LEN_EXTEND = 200; // guarantees a newline + private static final int HASHT_SIZE = 16384; // size of hash table, adjust tradeoff between colisions and cache utilization + private static final int DJB2_INIT = 5831; private static class MeasurementAggregator { // min, max, sum stored as 0.1/unit private MeasurementAggregator next = null; // linked list of entries for handling hash colisions @@ -48,11 +51,11 @@ public class CalculateAverage_netrunnereve { // djb2 hash private static int calc_hash(byte[] input, int len) { - int hash = 5831; + int hash = DJB2_INIT; for (int i = 0; i < len; i++) { hash = ((hash << 5) + hash) + Byte.toUnsignedInt(input[i]); } - return Math.abs(hash % 16384); + return Math.abs(hash % HASHT_SIZE); } private static class ThreadedParser extends Thread { @@ -60,65 +63,39 @@ public class CalculateAverage_netrunnereve { private int mbs; private ThreadCalcs[] threadOut; private int threadID; + private CountDownLatch tpLatch; - private ThreadedParser(MappedByteBuffer mbuf, int mbs, ThreadCalcs[] threadOut, int threadID) { + private ThreadedParser(MappedByteBuffer mbuf, int mbs, ThreadCalcs[] threadOut, int threadID, CountDownLatch tpLatch) { this.mbuf = mbuf; this.mbs = mbs; this.threadOut = threadOut; this.threadID = threadID; + this.tpLatch = tpLatch; } public void run() { - MeasurementAggregator[] hashSpace = new MeasurementAggregator[16384]; // 14-bit hash + MeasurementAggregator[] hashSpace = new MeasurementAggregator[HASHT_SIZE]; // hash table byte[] scratch = new byte[100]; // <= 100 characters in station name String[] staArr = new String[10000]; // max 10000 station names MeasurementAggregator ma = null; int numStations = 0; - boolean state = false; // 0 for station pickup, 1 for measurement pickup int negMul = 1; int head = 0; int tempCnt = -1; // 0 if 1 digit measurement, 1 if 2 digit + int hash = DJB2_INIT; // do calc_hash manually in loop - for (int i = 0; i < mbs; i++) { + int i = 0; // byte by byte iterator + while (true) { byte cur = mbuf.get(i); - if (state == true) { - if (cur == 46) { // . - int tempa = mbuf.get(i + 1) - 48; - tempa += (scratch[0] - 48) * (10 + 90 * tempCnt) + (scratch[1] - 48) * (10 * tempCnt); // branchless - tempa *= negMul; - - if (tempa < ma.min) { - ma.min = tempa; - } - if (tempa > ma.max) { - ma.max = tempa; - } - ma.sum += tempa; - ma.count++; - - i += 2; // go to start of new line - state = false; - negMul = 1; - head = i + 1; - tempCnt = -1; - } - else if (cur == 45) { // ascii - - negMul = -1; - } - else { - scratch[tempCnt + 1] = cur; - tempCnt++; - } - } - else if (cur == 59) { // ; - int len = i - head; + if (cur == 59) { // ; + hash = Math.abs(hash % HASHT_SIZE); // this is faster than filling scratch immediately after each byte is read + int len = i - head; mbuf.position(head); mbuf.get(scratch, 0, len); - int hash = calc_hash(scratch, len); ma = hashSpace[hash]; MeasurementAggregator prev = null; @@ -146,14 +123,53 @@ public class CalculateAverage_netrunnereve { break; } } - state = true; - head = i + 1; + + i++; + while (true) { + cur = mbuf.get(i); + if (cur == 46) { // . + int tempa = (negMul) * ((10 + 90 * tempCnt) * (scratch[0] - 48) + (10 * tempCnt) * (scratch[1] - 48) + (mbuf.get(i + 1) - 48)); // branchless + + if (tempa < ma.min) { + ma.min = tempa; + } + if (tempa > ma.max) { + ma.max = tempa; + } + ma.sum += tempa; + ma.count++; + + // this line is finished! + i += 2; // newline char + hash = DJB2_INIT; + negMul = 1; + head = i + 1; // start of next line + tempCnt = -1; + break; + } + else if (cur == 45) { // ascii - + negMul = -1; + } + else { + scratch[tempCnt + 1] = cur; + tempCnt++; + } + i++; + } + if (head >= mbs) { + break; + } } + else { + hash = ((hash << 5) + hash) + Byte.toUnsignedInt(cur); + } + i++; } threadOut[threadID] = new ThreadCalcs(); threadOut[threadID].hashSpace = hashSpace; threadOut[threadID].staArr = staArr; threadOut[threadID].numStations = numStations; + tpLatch.countDown(); } } @@ -175,8 +191,8 @@ public class CalculateAverage_netrunnereve { bufSize = Integer.MAX_VALUE; } - ThreadedParser[] myThreads = new ThreadedParser[(int) threadNum]; ThreadCalcs[] threadOut = new ThreadCalcs[(int) threadNum]; + CountDownLatch tpLatch = new CountDownLatch((int) threadNum); int threadID = 0; long h = 0; @@ -206,27 +222,25 @@ public class CalculateAverage_netrunnereve { } } - myThreads[threadID] = new ThreadedParser(mbuf, mbs, threadOut, threadID); - myThreads[threadID].start(); + ThreadedParser tpThr = new ThreadedParser(mbuf, mbs, threadOut, threadID, tpLatch); + tpThr.start(); h += mbs; threadID++; } - for (int i = 0; i < threadID; i++) { - try { - myThreads[i].join(); - } - catch (InterruptedException ex) { - System.exit(1); - } + try { + tpLatch.await(); + } + catch (InterruptedException ex) { + System.exit(1); } // use treemap to sort and uniquify - Map staMap = new TreeMap<>(); + Map staMap = new TreeMap<>(); for (int i = 0; i < threadID; i++) { for (int j = 0; j < threadOut[i].numStations; j++) { - staMap.put(threadOut[i].staArr[j], 0); + staMap.put(threadOut[i].staArr[j], false); } }