From cd0e20b304d92b2e0774cf20bd6b3c0ccae9b21f Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Mon, 15 Jan 2024 17:39:36 +0000 Subject: [PATCH] multithreaded version! (#415) --- .../onebrc/CalculateAverage_netrunnereve.java | 282 ++++++++++++------ 1 file changed, 184 insertions(+), 98 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_netrunnereve.java b/src/main/java/dev/morling/onebrc/CalculateAverage_netrunnereve.java index 7ff3cdd..e323a32 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_netrunnereve.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_netrunnereve.java @@ -22,10 +22,14 @@ import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.lang.Math; +import java.util.Map; +import java.util.TreeMap; public class CalculateAverage_netrunnereve { private static final String FILE = "./measurements.txt"; + private static final int NUM_THREADS = 8; // test machine + private static final int LEN_EXTEND = 200; // guarantees a newline private static class MeasurementAggregator { // min, max, sum stored as 0.1/unit private MeasurementAggregator next = null; // linked list of entries for handling hash colisions @@ -36,6 +40,12 @@ public class CalculateAverage_netrunnereve { private int count = 0; } + private static class ThreadCalcs { + private MeasurementAggregator[] hashSpace = null; + private String[] staArr = null; + private int numStations = 0; + } + // djb2 hash private static int calc_hash(byte[] input, int len) { int hash = 5831; @@ -45,23 +55,139 @@ public class CalculateAverage_netrunnereve { return Math.abs(hash % 16384); } - public static void main(String[] args) { - try { - RandomAccessFile mraf = new RandomAccessFile(FILE, "r"); - long fileSize = mraf.getChannel().size(); - long bufSize = Integer.MAX_VALUE; // Java requirement is <= Integer.MAX_VALUE - int numStations = 0; + private static class ThreadedParser extends Thread { + private MappedByteBuffer mbuf; + private int mbs; + private ThreadCalcs[] threadOut; + private int threadID; + private ThreadedParser(MappedByteBuffer mbuf, int mbs, ThreadCalcs[] threadOut, int threadID) { + this.mbuf = mbuf; + this.mbs = mbs; + this.threadOut = threadOut; + this.threadID = threadID; + } + + public void run() { MeasurementAggregator[] hashSpace = new MeasurementAggregator[16384]; // 14-bit hash byte[] scratch = new byte[100]; // <= 100 characters in station name String[] staArr = new String[10000]; // max 10000 station names MeasurementAggregator ma = null; + int numStations = 0; + boolean state = false; // 0 for station pickup, 1 for measurement pickup + int negMul = 1; + int head = 0; + int tempCnt = -1; // 0 if 1 digit measurement, 1 if 2 digit + + for (int i = 0; i < mbs; i++) { + byte cur = mbuf.get(i); + if (state == true) { + if (cur == 46) { // . + int tempa = mbuf.get(i + 1) - 48; + tempa += (scratch[0] - 48) * (10 + 90 * tempCnt) + (scratch[1] - 48) * (10 * tempCnt); // branchless + tempa *= negMul; + + if (tempa < ma.min) { + ma.min = tempa; + } + if (tempa > ma.max) { + ma.max = tempa; + } + ma.sum += tempa; + ma.count++; + + i += 2; // go to start of new line + state = false; + negMul = 1; + head = i + 1; + tempCnt = -1; + } + else if (cur == 45) { // ascii - + negMul = -1; + } + else { + scratch[tempCnt + 1] = cur; + tempCnt++; + } + } + else if (cur == 59) { // ; + int len = i - head; + + // this is faster than filling scratch immediately after each byte is read + mbuf.position(head); + mbuf.get(scratch, 0, len); + + int hash = calc_hash(scratch, len); + ma = hashSpace[hash]; + MeasurementAggregator prev = null; + + while (true) { + if (ma == null) { + ma = new MeasurementAggregator(); + ma.station = Arrays.copyOfRange(scratch, 0, len); + staArr[numStations] = new String(scratch, 0, len, StandardCharsets.UTF_8); + + if (prev != null) { + prev.next = ma; + } + else { + hashSpace[hash] = ma; + } + + numStations++; + break; + } + else if ((len != ma.station.length) || (Arrays.compare(scratch, 0, len, ma.station, 0, len) != 0)) { // hash collision + prev = ma; + ma = ma.next; + } + else { // hit + break; + } + } + state = true; + head = i + 1; + } + } + threadOut[threadID] = new ThreadCalcs(); + threadOut[threadID].hashSpace = hashSpace; + threadOut[threadID].staArr = staArr; + threadOut[threadID].numStations = numStations; + } + } + + public static void main(String[] args) { + try { + RandomAccessFile mraf = new RandomAccessFile(FILE, "r"); + long fileSize = mraf.getChannel().size(); + long threadNum = NUM_THREADS; + + long minThreads = (fileSize / Integer.MAX_VALUE) + 1; // minimum # of threads required due to MappedByteBuffer size limit + if (threadNum < minThreads) { + threadNum = minThreads; + } + long bufSize = fileSize / threadNum; + + // don't bother multithreading for small files + if (bufSize < 1000000) { + threadNum = 1; + bufSize = Integer.MAX_VALUE; + } + + ThreadedParser[] myThreads = new ThreadedParser[(int) threadNum]; + ThreadCalcs[] threadOut = new ThreadCalcs[(int) threadNum]; + int threadID = 0; + long h = 0; while (h < fileSize) { long length = bufSize; boolean finished = false; - if (h + length > fileSize) { + + if ((h == 0) && (length + LEN_EXTEND < Integer.MAX_VALUE)) { // add a bit of extra bytes to first thread to avoid generating new thread for the remainder + length += LEN_EXTEND; // arbitary bytes to guarantee a newline somewhere + } + if (h + length > fileSize) { // past the end length = fileSize - h; finished = true; } @@ -80,109 +206,69 @@ public class CalculateAverage_netrunnereve { } } - boolean state = false; // 0 for station pickup, 1 for measurement pickup - int negMul = 1; - int head = 0; - int tempCnt = -1; // 0 if 1 digit measurement, 1 if 2 digit + myThreads[threadID] = new ThreadedParser(mbuf, mbs, threadOut, threadID); + myThreads[threadID].start(); - for (int i = 0; i < mbs; i++) { - byte cur = mbuf.get(i); - if (state == true) { - if (cur == 46) { // . - int tempa = mbuf.get(i + 1) - 48; - tempa += (scratch[0] - 48) * (10 + 90 * tempCnt) + (scratch[1] - 48) * (10 * tempCnt); // branchless - tempa *= negMul; - - if (tempa < ma.min) { - ma.min = tempa; - } - if (tempa > ma.max) { - ma.max = tempa; - } - ma.sum += tempa; - ma.count++; - - i += 2; // go to start of new line - state = false; - negMul = 1; - head = i + 1; - tempCnt = -1; - } - else if (cur == 45) { // ascii - - negMul = -1; - } - else { - scratch[tempCnt + 1] = cur; - tempCnt++; - } - } - else if (cur == 59) { // ; - int len = i - head; - - // this is faster than filling scratch immediately after each byte is read - mbuf.position(head); - mbuf.get(scratch, 0, len); - - int hash = calc_hash(scratch, len); - ma = hashSpace[hash]; - MeasurementAggregator prev = null; - - while (true) { - if (ma == null) { - ma = new MeasurementAggregator(); - ma.station = Arrays.copyOfRange(scratch, 0, len); - staArr[numStations] = new String(scratch, 0, len, StandardCharsets.UTF_8); - - if (prev != null) { - prev.next = ma; - } - else { - hashSpace[hash] = ma; - } - - numStations++; - break; - } - else if ((len != ma.station.length) || (Arrays.compare(scratch, 0, len, ma.station, 0, len) != 0)) { // hash collision - prev = ma; - ma = ma.next; - } - else { // hit - break; - } - } - state = true; - head = i + 1; - } - } h += mbs; + threadID++; } - Arrays.sort(staArr, 0, numStations); + for (int i = 0; i < threadID; i++) { + try { + myThreads[i].join(); + } + catch (InterruptedException ex) { + System.exit(1); + } + } + // use treemap to sort and uniquify + Map staMap = new TreeMap<>(); + for (int i = 0; i < threadID; i++) { + for (int j = 0; j < threadOut[i].numStations; j++) { + staMap.put(threadOut[i].staArr[j], 0); + } + } + + boolean started = false; String out = "{"; - for (int i = 0; i < numStations; i++) { - byte[] strBuf = staArr[i].getBytes(StandardCharsets.UTF_8); + for (String i : staMap.keySet()) { + if (started) { + out += ", "; + } + else { + started = true; + } + + byte[] strBuf = i.getBytes(StandardCharsets.UTF_8); int hash = calc_hash(strBuf, strBuf.length); - ma = hashSpace[hash]; + MeasurementAggregator mSum = new MeasurementAggregator(); + for (int j = 0; j < threadID; j++) { + MeasurementAggregator ma = threadOut[j].hashSpace[hash]; - while (true) { - if ((strBuf.length != ma.station.length) || (Arrays.compare(strBuf, ma.station) != 0)) { // hash collision - ma = ma.next; - continue; - } - else { // hit - double min = Math.round(Double.valueOf(ma.min)) / 10.0; - double avg = Math.round(Double.valueOf(ma.sum) / Double.valueOf(ma.count)) / 10.0; - double max = Math.round(Double.valueOf(ma.max)) / 10.0; - out += staArr[i] + "=" + min + "/" + avg + "/" + max; - if (i != (numStations - 1)) { - out += ", "; + while (true) { + if ((strBuf.length != ma.station.length) || (Arrays.compare(strBuf, ma.station) != 0)) { // hash collision + ma = ma.next; + continue; + } + else { // hit + if (ma.min < mSum.min) { + mSum.min = ma.min; + } + if (ma.max > mSum.max) { + mSum.max = ma.max; + } + mSum.sum += ma.sum; + mSum.count += ma.count; + break; } - break; } } + double min = Math.round(Double.valueOf(mSum.min)) / 10.0; + double avg = Math.round(Double.valueOf(mSum.sum) / Double.valueOf(mSum.count)) / 10.0; + double max = Math.round(Double.valueOf(mSum.max)) / 10.0; + out += i + "=" + min + "/" + avg + "/" + max; } out += "}\n"; System.out.print(out);