netrunnereve: more optimizations (#485)

Authored by Eve on 2024-01-19 20:44:22 +00:00, committed by GitHub
parent ce8fe41bd4
commit 144a6af164


@@ -21,15 +21,18 @@ import java.io.RandomAccessFile;
 import java.nio.MappedByteBuffer;
 import java.nio.channels.FileChannel;
 import java.nio.charset.StandardCharsets;
-import java.lang.Math;
 import java.util.Map;
 import java.util.TreeMap;
+import java.util.concurrent.CountDownLatch;
+import java.lang.Math;

 public class CalculateAverage_netrunnereve {

     private static final String FILE = "./measurements.txt";
     private static final int NUM_THREADS = 8; // test machine
     private static final int LEN_EXTEND = 200; // guarantees a newline
+    private static final int HASHT_SIZE = 16384; // size of hash table, adjust tradeoff between colisions and cache utilization
+    private static final int DJB2_INIT = 5831;

     private static class MeasurementAggregator { // min, max, sum stored as 0.1/unit
         private MeasurementAggregator next = null; // linked list of entries for handling hash colisions
@@ -48,11 +51,11 @@ public class CalculateAverage_netrunnereve {

     // djb2 hash
     private static int calc_hash(byte[] input, int len) {
-        int hash = 5831;
+        int hash = DJB2_INIT;
         for (int i = 0; i < len; i++) {
             hash = ((hash << 5) + hash) + Byte.toUnsignedInt(input[i]);
         }
-        return Math.abs(hash % 16384);
+        return Math.abs(hash % HASHT_SIZE);
     }

     private static class ThreadedParser extends Thread {
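
For context, the bucket index computed by calc_hash is plain djb2 (hash = hash * 33 + byte) folded into the table size; later in this diff the same update is inlined byte-by-byte into the parse loop. A minimal standalone sketch of the scheme, with an illustrative method name that is not part of the commit:

    // Sketch of the djb2 bucket index used by this entry (not the committed code).
    static int djb2Bucket(byte[] input, int len) {
        int hash = 5831;                                                     // DJB2_INIT in the entry
        for (int i = 0; i < len; i++) {
            hash = ((hash << 5) + hash) + Byte.toUnsignedInt(input[i]);     // hash * 33 + byte
        }
        return Math.abs(hash % 16384);                                       // HASHT_SIZE buckets; abs() keeps the index non-negative if the hash overflows
    }

The result selects one of the 16384 linked-list buckets in hashSpace, with collisions chained through MeasurementAggregator.next.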
@@ -60,65 +63,39 @@ public class CalculateAverage_netrunnereve {
         private int mbs;
         private ThreadCalcs[] threadOut;
         private int threadID;
+        private CountDownLatch tpLatch;

-        private ThreadedParser(MappedByteBuffer mbuf, int mbs, ThreadCalcs[] threadOut, int threadID) {
+        private ThreadedParser(MappedByteBuffer mbuf, int mbs, ThreadCalcs[] threadOut, int threadID, CountDownLatch tpLatch) {
             this.mbuf = mbuf;
             this.mbs = mbs;
             this.threadOut = threadOut;
             this.threadID = threadID;
+            this.tpLatch = tpLatch;
         }

         public void run() {
-            MeasurementAggregator[] hashSpace = new MeasurementAggregator[16384]; // 14-bit hash
+            MeasurementAggregator[] hashSpace = new MeasurementAggregator[HASHT_SIZE]; // hash table
             byte[] scratch = new byte[100]; // <= 100 characters in station name
             String[] staArr = new String[10000]; // max 10000 station names
             MeasurementAggregator ma = null;
             int numStations = 0;
-            boolean state = false; // 0 for station pickup, 1 for measurement pickup
             int negMul = 1;
             int head = 0;
             int tempCnt = -1; // 0 if 1 digit measurement, 1 if 2 digit
+            int hash = DJB2_INIT; // do calc_hash manually in loop

-            for (int i = 0; i < mbs; i++) {
+            int i = 0; // byte by byte iterator
+            while (true) {
                 byte cur = mbuf.get(i);
-                if (state == true) {
-                    if (cur == 46) { // .
-                        int tempa = mbuf.get(i + 1) - 48;
-                        tempa += (scratch[0] - 48) * (10 + 90 * tempCnt) + (scratch[1] - 48) * (10 * tempCnt); // branchless
-                        tempa *= negMul;
-                        if (tempa < ma.min) {
-                            ma.min = tempa;
-                        }
-                        if (tempa > ma.max) {
-                            ma.max = tempa;
-                        }
-                        ma.sum += tempa;
-                        ma.count++;
-                        i += 2; // go to start of new line
-                        state = false;
-                        negMul = 1;
-                        head = i + 1;
-                        tempCnt = -1;
-                    }
-                    else if (cur == 45) { // ascii -
-                        negMul = -1;
-                    }
-                    else {
-                        scratch[tempCnt + 1] = cur;
-                        tempCnt++;
-                    }
-                }
-                else if (cur == 59) { // ;
-                    int len = i - head;
+                if (cur == 59) { // ;
+                    hash = Math.abs(hash % HASHT_SIZE);
                     // this is faster than filling scratch immediately after each byte is read
+                    int len = i - head;
                     mbuf.position(head);
                     mbuf.get(scratch, 0, len);
-                    int hash = calc_hash(scratch, len);

                     ma = hashSpace[hash];
                     MeasurementAggregator prev = null;
@@ -146,14 +123,53 @@ public class CalculateAverage_netrunnereve {
                             break;
                         }
                     }
-                    state = true;
-                    head = i + 1;
+                    i++;
+                    while (true) {
+                        cur = mbuf.get(i);
+                        if (cur == 46) { // .
+                            int tempa = (negMul) * ((10 + 90 * tempCnt) * (scratch[0] - 48) + (10 * tempCnt) * (scratch[1] - 48) + (mbuf.get(i + 1) - 48)); // branchless
+                            if (tempa < ma.min) {
+                                ma.min = tempa;
+                            }
+                            if (tempa > ma.max) {
+                                ma.max = tempa;
+                            }
+                            ma.sum += tempa;
+                            ma.count++;
+
+                            // this line is finished!
+                            i += 2; // newline char
+                            hash = DJB2_INIT;
+                            negMul = 1;
+                            head = i + 1; // start of next line
+                            tempCnt = -1;
+                            break;
+                        }
+                        else if (cur == 45) { // ascii -
+                            negMul = -1;
+                        }
+                        else {
+                            scratch[tempCnt + 1] = cur;
+                            tempCnt++;
+                        }
+                        i++;
+                    }
+                    if (head >= mbs) {
+                        break;
+                    }
                 }
+                else {
+                    hash = ((hash << 5) + hash) + Byte.toUnsignedInt(cur);
+                }
+                i++;
             }

             threadOut[threadID] = new ThreadCalcs();
             threadOut[threadID].hashSpace = hashSpace;
             threadOut[threadID]..staArr = staArr;
             threadOut[threadID].numStations = numStations;
+            tpLatch.countDown();
         }
     }
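
The "// branchless" expression above folds one- and two-digit integer parts into a single formula: tempCnt is 0 when the reading has one integer digit and 1 when it has two, so the weights (10 + 90 * tempCnt) and (10 * tempCnt) evaluate to 10/0 for a reading like 7.2 and 100/10 for a reading like 42.1, and the stale byte in scratch[1] is multiplied away by zero in the one-digit case. A small sketch of the same arithmetic with the ASCII-to-digit conversion already done (the helper name is illustrative only):

    // Branchless parse of a reading into tenths of a degree (sketch, not the committed code).
    // tempCnt: 0 for one integer digit, 1 for two; d0/d1: integer digits (d1 ignored when tempCnt is 0); frac: digit after '.'.
    static int parseTenths(int negMul, int tempCnt, int d0, int d1, int frac) {
        return negMul * ((10 + 90 * tempCnt) * d0 + (10 * tempCnt) * d1 + frac);
    }

    // "7.2"   -> parseTenths( 1, 0, 7, 0, 2) ==   72
    // "42.1"  -> parseTenths( 1, 1, 4, 2, 1) ==  421
    // "-42.1" -> parseTenths(-1, 1, 4, 2, 1) == -421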
@@ -175,8 +191,8 @@ public class CalculateAverage_netrunnereve {
             bufSize = Integer.MAX_VALUE;
         }

-        ThreadedParser[] myThreads = new ThreadedParser[(int) threadNum];
         ThreadCalcs[] threadOut = new ThreadCalcs[(int) threadNum];
+        CountDownLatch tpLatch = new CountDownLatch((int) threadNum);

         int threadID = 0;
         long h = 0;
@@ -206,27 +222,25 @@
                 }
             }

-            myThreads[threadID] = new ThreadedParser(mbuf, mbs, threadOut, threadID);
-            myThreads[threadID].start();
+            ThreadedParser tpThr = new ThreadedParser(mbuf, mbs, threadOut, threadID, tpLatch);
+            tpThr.start();

             h += mbs;
             threadID++;
         }

-        for (int i = 0; i < threadID; i++) {
-            try {
-                myThreads[i].join();
-            }
-            catch (InterruptedException ex) {
-                System.exit(1);
-            }
+        try {
+            tpLatch.await();
+        }
+        catch (InterruptedException ex) {
+            System.exit(1);
         }

         // use treemap to sort and uniquify
-        Map<String, Integer> staMap = new TreeMap<>();
+        Map<String, Boolean> staMap = new TreeMap<>();
        for (int i = 0; i < threadID; i++) {
             for (int j = 0; j < threadOut[i].numStations; j++) {
-                staMap.put(threadOut[i].staArr[j], 0);
+                staMap.put(threadOut[i].staArr[j], false);
             }
         }
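
The per-thread join() loop over the myThreads array is replaced by a single CountDownLatch: each ThreadedParser calls tpLatch.countDown() at the end of run(), and the main thread blocks in tpLatch.await() until all threadNum workers have finished, so no thread handles need to be retained. A minimal standalone sketch of that pattern (the worker body and count here are placeholders, not the actual parser):

    import java.util.concurrent.CountDownLatch;

    public class LatchSketch {
        public static void main(String[] args) throws InterruptedException {
            int workers = 8;                                   // NUM_THREADS in the entry
            CountDownLatch latch = new CountDownLatch(workers);
            for (int id = 0; id < workers; id++) {
                new Thread(() -> {
                    // ... parse one chunk of the file ...
                    latch.countDown();                         // like tpLatch.countDown() in run()
                }).start();
            }
            latch.await();                                     // like tpLatch.await() in main
        }
    }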