Updates for gamlerhart: Simpler & Faster (#580)

* Update with Rounding Bugfix

* Simplification of Merging Results

* More Plain Java Code for Value Storage

* Improve Performance by Stupid Hash

Drop around 3 seconds on my machine by
simplifying the hash to be ridiculously stupid,
but faster.

* Fix outdated comment
This commit is contained in:
Roman Stoffel 2024-01-25 23:12:10 +01:00 committed by GitHub
parent b20e7365e7
commit 94e29982f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -24,7 +24,10 @@ import java.nio.ByteOrder;
import java.nio.channels.FileChannel; import java.nio.channels.FileChannel;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Iterator;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import static java.lang.Double.doubleToRawLongBits; import static java.lang.Double.doubleToRawLongBits;
import static java.lang.Double.longBitsToDouble; import static java.lang.Double.longBitsToDouble;
@ -69,19 +72,16 @@ public class CalculateAverage_gamlerhart {
ArrayList<Section> sections = splitFileIntoSections(fileSize, fileContent); ArrayList<Section> sections = splitFileIntoSections(fileSize, fileContent);
var loopBound = byteVec.loopBound(fileSize) - vecLen; var loopBound = byteVec.loopBound(fileSize) - vecLen;
PrivateHashMap result = sections.stream() var result = sections.stream()
.parallel() .parallel()
.map(s -> { .map(s -> {
return parseSection(s.start, s.end, loopBound, fileContent); return parseSection(s.start, s.end, loopBound, fileContent);
}).reduce((mine, other) -> { });
assert mine != other;
mine.mergeFrom(fileContent, other);
return mine;
})
.get();
var measurements = new TreeMap<String, ResultRow>(); var measurements = new TreeMap<String, ResultRow>();
result.fill(fileContent, measurements); result.forEachOrdered(m -> {
m.fillMerge(fileContent, measurements);
});
System.out.println(measurements); System.out.println(measurements);
} }
} }
@ -160,11 +160,22 @@ public class CalculateAverage_gamlerhart {
// Encoding: // Encoding:
// - Key: long // - Key: long
// - 48 bits index, 16 bits length // - 48 bits index, 16 bits length
// - min: double final long[] keys = new long[SIZE];
// - max: double final Value[] values = new Value[SIZE];
// - sum: double
// - double: double private class Value {
final long[] keyValues = new long[SIZE * 5]; public Value(double min, double max, double sum, long count) {
this.min = min;
this.max = max;
this.sum = sum;
this.count = count;
}
public double min;
public double max;
public double sum;
public long count;
}
// int debug_size = 0; // int debug_size = 0;
@ -179,43 +190,40 @@ public class CalculateAverage_gamlerhart {
} }
private static int calculateHash(MemorySegment file, long pos, int len) { private static int calculateHash(MemorySegment file, long pos, int len) {
int hashCode = 1; if (len > 4) {
int i = 0; return file.get(INT_UNALIGNED_BIG_ENDIAN, pos) + 31 * len;
int intBound = (len / 4) * 4;
for (; i < intBound; i += 4) {
int v = file.get(INT_UNALIGNED_BIG_ENDIAN, pos + i);
hashCode = 31 * hashCode + v;
} }
for (; i < len; i++) { else {
int v = file.get(JAVA_BYTE, pos + i); int hashCode = len;
hashCode = 31 * hashCode + v; int i = 0;
for (; i < len; i++) {
int v = file.get(JAVA_BYTE, pos + i);
hashCode = 31 * hashCode + v;
}
return hashCode;
} }
return hashCode;
} }
private void doAdd(MemorySegment file, int hash, long pos, int len, double val) { private void doAdd(MemorySegment file, int hash, long pos, int len, double val) {
int slot = hash & MASK; int slot = hash & MASK;
for (var probe = 0; probe < 20000; probe++) { for (var probe = 0; probe < 20000; probe++) {
var iSl = ((slot + probe) & MASK) * 5; var iSl = ((slot + probe) & MASK);
var slotEntry = keyValues[iSl]; var slotEntry = keys[iSl];
var emtpy = slotEntry == 0; var emtpy = slotEntry == 0;
if (emtpy) { if (emtpy) {
long keyInfo = pos << SHIFT_POS | len; long keyInfo = pos << SHIFT_POS | len;
long valueBits = doubleToRawLongBits(val); keys[iSl] = keyInfo;
keyValues[iSl] = keyInfo; values[iSl] = new Value(val, val, val, 1);
keyValues[iSl + 1] = valueBits;
keyValues[iSl + 2] = valueBits;
keyValues[iSl + 3] = valueBits;
keyValues[iSl + 4] = 1;
// debug_size++; // debug_size++;
return; return;
} }
else if (isSameEntry(file, slotEntry, pos, len)) { else if (isSameEntry(file, slotEntry, pos, len)) {
keyValues[iSl + 1] = doubleToRawLongBits(Math.min(longBitsToDouble(keyValues[iSl + 1]), val)); var vE = values[iSl];
keyValues[iSl + 2] = doubleToRawLongBits(Math.max(longBitsToDouble(keyValues[iSl + 2]), val)); vE.min = Math.min(vE.min, val);
keyValues[iSl + 3] = doubleToRawLongBits(longBitsToDouble(keyValues[iSl + 3]) + val); vE.max = Math.max(vE.max, val);
keyValues[iSl + 4] = keyValues[iSl + 4] + 1; vE.sum = vE.sum + val;
vE.count++;
return; return;
} }
else { else {
@ -268,118 +276,65 @@ public class CalculateAverage_gamlerhart {
return true; return true;
} }
public PrivateHashMap mergeFrom(MemorySegment file, PrivateHashMap other) { public void fillMerge(MemorySegment file, TreeMap<String, ResultRow> treeMap) {
for (int slot = 0; slot < other.keyValues.length / 5; slot++) { for (int i = 0; i < keys.length; i++) {
int srcI = slot * 5; var ji = i;
long keyE = other.keyValues[srcI]; long keyE = keys[ji];
if (keyE != 0) {
long oPos = (keyE & MASK_POS) >> SHIFT_POS;
int oLen = (int) (keyE & MASK_LEN);
addMerge(file, other, srcI, oPos, oLen);
}
}
return this;
}
private void addMerge(MemorySegment file, PrivateHashMap other, int srcI, long oPos, int oLen) {
int slot = calculateHash(file, oPos, oLen) & MASK;
for (var probe = 0; probe < 20000; probe++) {
var iSl = ((slot + probe) & MASK) * 5;
var slotEntry = keyValues[iSl];
var emtpy = slotEntry == 0;
// var debugKey = new String(file.asSlice(oPos, oLen).toArray(JAVA_BYTE));
if (emtpy) {
// if (debugKey.equals("Cabo San Lucas")) {
// System.out.println("=> VALUES (init) " + debugKey + "@" + iSl + " max: " + longBitsToDouble(other.keyValues[srcI + 2]) + "," + longBitsToDouble(keyValues[iSl + 2]));
// }
keyValues[iSl] = other.keyValues[srcI];
keyValues[iSl + 1] = other.keyValues[srcI + 1];
keyValues[iSl + 2] = other.keyValues[srcI + 2];
keyValues[iSl + 3] = other.keyValues[srcI + 3];
keyValues[iSl + 4] = other.keyValues[srcI + 4];
// debug_size++;
return;
}
else if (isSameEntry(file, slotEntry, oPos, oLen)) {
// if (debugKey.equals("Cabo San Lucas")) {
// System.out.println("=> VALUES (merge) " + "@" + iSl + debugKey + " max: " + longBitsToDouble(other.keyValues[srcI + 2]) + ","
// + longBitsToDouble(keyValues[iSl + 2]) + "=> "
// + Math.max(longBitsToDouble(keyValues[iSl + 2]), longBitsToDouble(other.keyValues[srcI + 2])));
// }
keyValues[iSl + 1] = doubleToRawLongBits(Math.min(longBitsToDouble(keyValues[iSl + 1]), longBitsToDouble(other.keyValues[srcI + 1])));
keyValues[iSl + 2] = doubleToRawLongBits(Math.max(longBitsToDouble(keyValues[iSl + 2]), longBitsToDouble(other.keyValues[srcI + 2])));
keyValues[iSl + 3] = doubleToRawLongBits(longBitsToDouble(keyValues[iSl + 3]) + longBitsToDouble(other.keyValues[srcI + 3]));
keyValues[iSl + 4] = keyValues[iSl + 4] + other.keyValues[srcI + 4];
// if (debugKey.equals("Cabo San Lucas")) {
// System.out.println("=> VALUES (after-merge) self: "+ "@" + iSl + System.identityHashCode(this) + ":"+ debugKey + " max: " +
// + longBitsToDouble(keyValues[iSl + 2]) + "=> ");
// }
return;
}
else {
// long keyPos = (slotEntry & MASK_POS) >> SHIFT_POS;
// int keyLen = (int) (slotEntry & MASK_LEN);
// System.out.println("Colliding " + new String(file.asSlice(pos,len).toArray(ValueLayout.JAVA_BYTE)) +
// " with key" + new String(file.asSlice(keyPos,keyLen).toArray(ValueLayout.JAVA_BYTE)) +
// " hash " + hash + " slot " + slot + "+" + probe + " at " + iSl);
// debug_reprobeMax = Math.max(debug_reprobeMax, probe);
}
}
throw new IllegalStateException("More than 20000 reprobes");
}
public void fill(MemorySegment file, TreeMap<String, ResultRow> treeMap) {
for (int i = 0; i < keyValues.length / 5; i++) {
var ji = i * 5;
long keyE = keyValues[ji];
if (keyE != 0) { if (keyE != 0) {
long keyPos = (keyE & MASK_POS) >> SHIFT_POS; long keyPos = (keyE & MASK_POS) >> SHIFT_POS;
int keyLen = (int) (keyE & MASK_LEN); int keyLen = (int) (keyE & MASK_LEN);
byte[] keyBytes = new byte[keyLen]; byte[] keyBytes = new byte[keyLen];
MemorySegment.copy(file, JAVA_BYTE, keyPos, keyBytes, 0, keyLen); MemorySegment.copy(file, JAVA_BYTE, keyPos, keyBytes, 0, keyLen);
var key = new String(keyBytes); var key = new String(keyBytes);
var min = longBitsToDouble(keyValues[ji + 1]); var vE = values[ji];
var max = longBitsToDouble(keyValues[ji + 2]); var min = vE.min;
var sum = longBitsToDouble(keyValues[ji + 3]); var max = vE.max;
var count = keyValues[ji + 4]; var sum = vE.sum;
treeMap.put(key, new ResultRow(min, sum / count, max)); var count = vE.count;
treeMap.compute(key, (k, e) -> {
if (e == null) {
return new ResultRow(min, max, sum, count);
}
else {
return new ResultRow(Math.min(e.min, min), Math.max(e.max, max), e.sum + sum, e.count + count);
}
});
} }
} }
} }
public String debugPrint(MemorySegment file) { // public String debugPrint(MemorySegment file) {
StringBuilder b = new StringBuilder(); // StringBuilder b = new StringBuilder();
for (int i = 0; i < keyValues.length / 5; i++) { // for (int i = 0; i < keyValues.length / 5; i++) {
var ji = i * 5; // var ji = i * 5;
long keyE = keyValues[ji]; // long keyE = keyValues[ji];
if (keyE != 0) { // if (keyE != 0) {
long keyPos = (keyE & MASK_POS) >> SHIFT_POS; // long keyPos = (keyE & MASK_POS) >> SHIFT_POS;
int keyLen = (int) (keyE & MASK_LEN); // int keyLen = (int) (keyE & MASK_LEN);
byte[] keyBytes = new byte[keyLen]; // byte[] keyBytes = new byte[keyLen];
MemorySegment.copy(file, JAVA_BYTE, keyPos, keyBytes, 0, keyLen); // MemorySegment.copy(file, JAVA_BYTE, keyPos, keyBytes, 0, keyLen);
var key = new String(keyBytes); // var key = new String(keyBytes);
var min = longBitsToDouble(keyValues[ji + 1]); // var min = longBitsToDouble(keyValues[ji + 1]);
var max = longBitsToDouble(keyValues[ji + 2]); // var max = longBitsToDouble(keyValues[ji + 2]);
var sum = longBitsToDouble(keyValues[ji + 3]); // var sum = longBitsToDouble(keyValues[ji + 3]);
var count = keyValues[ji + 4]; // var count = keyValues[ji + 4];
b.append("{").append(key).append("@").append(ji) // b.append("{").append(key).append("@").append(ji)
.append(",").append(min) // .append(",").append(min)
.append(",").append(max) // .append(",").append(max)
.append(",").append(sum) // .append(",").append(sum)
.append(",").append(count).append("},"); // .append(",").append(count).append("},");
} // }
} // }
return b.toString(); // return b.toString();
} // }
} }
record Section(long start, long end) { record Section(long start, long end) {
} }
private static record ResultRow(double min, double mean, double max) { private static record ResultRow(double min, double max, double sum, long count) {
public String toString() { public String toString() {
return round(min) + "/" + round(mean) + "/" + round(max); return round(min) + "/" + round(((Math.round(sum * 10.0) / 10.0) / count)) + "/" + round(max);
} }
private double round(double value) { private double round(double value) {