multithreaded version! (#415)

This commit is contained in:
Eve 2024-01-15 17:39:36 +00:00 committed by GitHub
parent 61f5618ff2
commit cd0e20b304
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -22,10 +22,14 @@ import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel; import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.lang.Math; import java.lang.Math;
import java.util.Map;
import java.util.TreeMap;
public class CalculateAverage_netrunnereve { public class CalculateAverage_netrunnereve {
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
private static final int NUM_THREADS = 8; // test machine
private static final int LEN_EXTEND = 200; // guarantees a newline
private static class MeasurementAggregator { // min, max, sum stored as 0.1/unit private static class MeasurementAggregator { // min, max, sum stored as 0.1/unit
private MeasurementAggregator next = null; // linked list of entries for handling hash colisions private MeasurementAggregator next = null; // linked list of entries for handling hash colisions
@ -36,6 +40,12 @@ public class CalculateAverage_netrunnereve {
private int count = 0; private int count = 0;
} }
private static class ThreadCalcs {
private MeasurementAggregator[] hashSpace = null;
private String[] staArr = null;
private int numStations = 0;
}
// djb2 hash // djb2 hash
private static int calc_hash(byte[] input, int len) { private static int calc_hash(byte[] input, int len) {
int hash = 5831; int hash = 5831;
@ -45,23 +55,139 @@ public class CalculateAverage_netrunnereve {
return Math.abs(hash % 16384); return Math.abs(hash % 16384);
} }
public static void main(String[] args) { private static class ThreadedParser extends Thread {
try { private MappedByteBuffer mbuf;
RandomAccessFile mraf = new RandomAccessFile(FILE, "r"); private int mbs;
long fileSize = mraf.getChannel().size(); private ThreadCalcs[] threadOut;
long bufSize = Integer.MAX_VALUE; // Java requirement is <= Integer.MAX_VALUE private int threadID;
int numStations = 0;
private ThreadedParser(MappedByteBuffer mbuf, int mbs, ThreadCalcs[] threadOut, int threadID) {
this.mbuf = mbuf;
this.mbs = mbs;
this.threadOut = threadOut;
this.threadID = threadID;
}
public void run() {
MeasurementAggregator[] hashSpace = new MeasurementAggregator[16384]; // 14-bit hash MeasurementAggregator[] hashSpace = new MeasurementAggregator[16384]; // 14-bit hash
byte[] scratch = new byte[100]; // <= 100 characters in station name byte[] scratch = new byte[100]; // <= 100 characters in station name
String[] staArr = new String[10000]; // max 10000 station names String[] staArr = new String[10000]; // max 10000 station names
MeasurementAggregator ma = null; MeasurementAggregator ma = null;
int numStations = 0;
boolean state = false; // 0 for station pickup, 1 for measurement pickup
int negMul = 1;
int head = 0;
int tempCnt = -1; // 0 if 1 digit measurement, 1 if 2 digit
for (int i = 0; i < mbs; i++) {
byte cur = mbuf.get(i);
if (state == true) {
if (cur == 46) { // .
int tempa = mbuf.get(i + 1) - 48;
tempa += (scratch[0] - 48) * (10 + 90 * tempCnt) + (scratch[1] - 48) * (10 * tempCnt); // branchless
tempa *= negMul;
if (tempa < ma.min) {
ma.min = tempa;
}
if (tempa > ma.max) {
ma.max = tempa;
}
ma.sum += tempa;
ma.count++;
i += 2; // go to start of new line
state = false;
negMul = 1;
head = i + 1;
tempCnt = -1;
}
else if (cur == 45) { // ascii -
negMul = -1;
}
else {
scratch[tempCnt + 1] = cur;
tempCnt++;
}
}
else if (cur == 59) { // ;
int len = i - head;
// this is faster than filling scratch immediately after each byte is read
mbuf.position(head);
mbuf.get(scratch, 0, len);
int hash = calc_hash(scratch, len);
ma = hashSpace[hash];
MeasurementAggregator prev = null;
while (true) {
if (ma == null) {
ma = new MeasurementAggregator();
ma.station = Arrays.copyOfRange(scratch, 0, len);
staArr[numStations] = new String(scratch, 0, len, StandardCharsets.UTF_8);
if (prev != null) {
prev.next = ma;
}
else {
hashSpace[hash] = ma;
}
numStations++;
break;
}
else if ((len != ma.station.length) || (Arrays.compare(scratch, 0, len, ma.station, 0, len) != 0)) { // hash collision
prev = ma;
ma = ma.next;
}
else { // hit
break;
}
}
state = true;
head = i + 1;
}
}
threadOut[threadID] = new ThreadCalcs();
threadOut[threadID].hashSpace = hashSpace;
threadOut[threadID].staArr = staArr;
threadOut[threadID].numStations = numStations;
}
}
public static void main(String[] args) {
try {
RandomAccessFile mraf = new RandomAccessFile(FILE, "r");
long fileSize = mraf.getChannel().size();
long threadNum = NUM_THREADS;
long minThreads = (fileSize / Integer.MAX_VALUE) + 1; // minimum # of threads required due to MappedByteBuffer size limit
if (threadNum < minThreads) {
threadNum = minThreads;
}
long bufSize = fileSize / threadNum;
// don't bother multithreading for small files
if (bufSize < 1000000) {
threadNum = 1;
bufSize = Integer.MAX_VALUE;
}
ThreadedParser[] myThreads = new ThreadedParser[(int) threadNum];
ThreadCalcs[] threadOut = new ThreadCalcs[(int) threadNum];
int threadID = 0;
long h = 0; long h = 0;
while (h < fileSize) { while (h < fileSize) {
long length = bufSize; long length = bufSize;
boolean finished = false; boolean finished = false;
if (h + length > fileSize) {
if ((h == 0) && (length + LEN_EXTEND < Integer.MAX_VALUE)) { // add a bit of extra bytes to first thread to avoid generating new thread for the remainder
length += LEN_EXTEND; // arbitary bytes to guarantee a newline somewhere
}
if (h + length > fileSize) { // past the end
length = fileSize - h; length = fileSize - h;
finished = true; finished = true;
} }
@ -80,109 +206,69 @@ public class CalculateAverage_netrunnereve {
} }
} }
boolean state = false; // 0 for station pickup, 1 for measurement pickup myThreads[threadID] = new ThreadedParser(mbuf, mbs, threadOut, threadID);
int negMul = 1; myThreads[threadID].start();
int head = 0;
int tempCnt = -1; // 0 if 1 digit measurement, 1 if 2 digit
for (int i = 0; i < mbs; i++) {
byte cur = mbuf.get(i);
if (state == true) {
if (cur == 46) { // .
int tempa = mbuf.get(i + 1) - 48;
tempa += (scratch[0] - 48) * (10 + 90 * tempCnt) + (scratch[1] - 48) * (10 * tempCnt); // branchless
tempa *= negMul;
if (tempa < ma.min) {
ma.min = tempa;
}
if (tempa > ma.max) {
ma.max = tempa;
}
ma.sum += tempa;
ma.count++;
i += 2; // go to start of new line
state = false;
negMul = 1;
head = i + 1;
tempCnt = -1;
}
else if (cur == 45) { // ascii -
negMul = -1;
}
else {
scratch[tempCnt + 1] = cur;
tempCnt++;
}
}
else if (cur == 59) { // ;
int len = i - head;
// this is faster than filling scratch immediately after each byte is read
mbuf.position(head);
mbuf.get(scratch, 0, len);
int hash = calc_hash(scratch, len);
ma = hashSpace[hash];
MeasurementAggregator prev = null;
while (true) {
if (ma == null) {
ma = new MeasurementAggregator();
ma.station = Arrays.copyOfRange(scratch, 0, len);
staArr[numStations] = new String(scratch, 0, len, StandardCharsets.UTF_8);
if (prev != null) {
prev.next = ma;
}
else {
hashSpace[hash] = ma;
}
numStations++;
break;
}
else if ((len != ma.station.length) || (Arrays.compare(scratch, 0, len, ma.station, 0, len) != 0)) { // hash collision
prev = ma;
ma = ma.next;
}
else { // hit
break;
}
}
state = true;
head = i + 1;
}
}
h += mbs; h += mbs;
threadID++;
} }
Arrays.sort(staArr, 0, numStations); for (int i = 0; i < threadID; i++) {
try {
myThreads[i].join();
}
catch (InterruptedException ex) {
System.exit(1);
}
}
// use treemap to sort and uniquify
Map<String, Integer> staMap = new TreeMap<>();
for (int i = 0; i < threadID; i++) {
for (int j = 0; j < threadOut[i].numStations; j++) {
staMap.put(threadOut[i].staArr[j], 0);
}
}
boolean started = false;
String out = "{"; String out = "{";
for (int i = 0; i < numStations; i++) { for (String i : staMap.keySet()) {
byte[] strBuf = staArr[i].getBytes(StandardCharsets.UTF_8); if (started) {
out += ", ";
}
else {
started = true;
}
byte[] strBuf = i.getBytes(StandardCharsets.UTF_8);
int hash = calc_hash(strBuf, strBuf.length); int hash = calc_hash(strBuf, strBuf.length);
ma = hashSpace[hash]; MeasurementAggregator mSum = new MeasurementAggregator();
for (int j = 0; j < threadID; j++) {
MeasurementAggregator ma = threadOut[j].hashSpace[hash];
while (true) { while (true) {
if ((strBuf.length != ma.station.length) || (Arrays.compare(strBuf, ma.station) != 0)) { // hash collision if ((strBuf.length != ma.station.length) || (Arrays.compare(strBuf, ma.station) != 0)) { // hash collision
ma = ma.next; ma = ma.next;
continue; continue;
} }
else { // hit else { // hit
double min = Math.round(Double.valueOf(ma.min)) / 10.0; if (ma.min < mSum.min) {
double avg = Math.round(Double.valueOf(ma.sum) / Double.valueOf(ma.count)) / 10.0; mSum.min = ma.min;
double max = Math.round(Double.valueOf(ma.max)) / 10.0; }
out += staArr[i] + "=" + min + "/" + avg + "/" + max; if (ma.max > mSum.max) {
if (i != (numStations - 1)) { mSum.max = ma.max;
out += ", "; }
mSum.sum += ma.sum;
mSum.count += ma.count;
break;
} }
break;
} }
} }
double min = Math.round(Double.valueOf(mSum.min)) / 10.0;
double avg = Math.round(Double.valueOf(mSum.sum) / Double.valueOf(mSum.count)) / 10.0;
double max = Math.round(Double.valueOf(mSum.max)) / 10.0;
out += i + "=" + min + "/" + avg + "/" + max;
} }
out += "}\n"; out += "}\n";
System.out.print(out); System.out.print(out);