Deploy v2 for parkertimmins (#524)
* Deploy v2 for parkertimmins Main changes: - fix hash which masked incorrectly - do station equality check in simd - make station array length multiple of 32 - search for newline rather than semicolon * Fix bug - entries were being skipped between batches At the boundary between two batches, the first batch would stop after crossing a limit with a padding of 200 characters applied. The next batch should then start looking for the first full entry after the padding. This padding logic had been removed when starting a batch. For this reason, entries starting in the 200 character padding between batches were skipped.
This commit is contained in:
parent
d858959097
commit
c886aaba34
@ -16,28 +16,21 @@
|
|||||||
package dev.morling.onebrc;
|
package dev.morling.onebrc;
|
||||||
|
|
||||||
import jdk.incubator.vector.ByteVector;
|
import jdk.incubator.vector.ByteVector;
|
||||||
import jdk.incubator.vector.VectorMask;
|
|
||||||
import jdk.incubator.vector.VectorOperators;
|
|
||||||
|
|
||||||
import java.lang.foreign.Arena;
|
import java.lang.foreign.Arena;
|
||||||
import java.lang.foreign.MemorySegment;
|
import java.lang.foreign.MemorySegment;
|
||||||
|
|
||||||
import java.lang.foreign.ValueLayout;
|
import java.lang.foreign.ValueLayout;
|
||||||
import java.lang.reflect.Array;
|
|
||||||
import java.nio.ByteBuffer;
|
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.RandomAccessFile;
|
import java.io.RandomAccessFile;
|
||||||
import java.nio.MappedByteBuffer;
|
|
||||||
import java.nio.channels.FileChannel;
|
import java.nio.channels.FileChannel;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
import java.util.zip.CRC32C;
|
|
||||||
|
|
||||||
public class CalculateAverage_parkertimmins {
|
public class CalculateAverage_parkertimmins {
|
||||||
private static final String FILE = "./measurements.txt";
|
private static final String FILE = "./measurements.txt";
|
||||||
// private static final String FILE = "./full_measurements.no_license";
|
|
||||||
|
|
||||||
private static record ResultRow(double min, double mean, double max) {
|
private static record ResultRow(double min, double mean, double max) {
|
||||||
public String toString() {
|
public String toString() {
|
||||||
@ -51,14 +44,16 @@ public class CalculateAverage_parkertimmins {
|
|||||||
|
|
||||||
static class OpenHashTable {
|
static class OpenHashTable {
|
||||||
static class Entry {
|
static class Entry {
|
||||||
|
|
||||||
|
// key always stored as multiple of 32 bytes
|
||||||
byte[] key;
|
byte[] key;
|
||||||
short min;
|
byte keyLen;
|
||||||
short max;
|
short min = Short.MAX_VALUE;
|
||||||
|
short max = Short.MIN_VALUE;
|
||||||
long sum = 0;
|
long sum = 0;
|
||||||
long count = 0;
|
long count = 0;
|
||||||
int hash;
|
|
||||||
|
|
||||||
void merge(OpenHashTable.Entry other) {
|
void merge(Entry other) {
|
||||||
min = (short) Math.min(min, other.min);
|
min = (short) Math.min(min, other.min);
|
||||||
max = (short) Math.max(max, other.max);
|
max = (short) Math.max(max, other.max);
|
||||||
sum += other.sum;
|
sum += other.sum;
|
||||||
@ -80,15 +75,20 @@ public class CalculateAverage_parkertimmins {
|
|||||||
// key not present, so add it
|
// key not present, so add it
|
||||||
if (entry == null) {
|
if (entry == null) {
|
||||||
entry = entries[idx] = new Entry();
|
entry = entries[idx] = new Entry();
|
||||||
entry.key = Arrays.copyOf(buf, sLen);
|
|
||||||
|
int rem = sLen % 32;
|
||||||
|
int arrayLen = rem == 0 ? sLen : sLen + 32 - rem;
|
||||||
|
entry.key = Arrays.copyOf(buf, arrayLen);
|
||||||
|
Arrays.fill(entry.key, sLen, arrayLen, (byte) 0);
|
||||||
|
entry.keyLen = (byte) sLen;
|
||||||
|
|
||||||
entry.min = entry.max = val;
|
entry.min = entry.max = val;
|
||||||
entry.sum += val;
|
entry.sum += val;
|
||||||
entry.count++;
|
entry.count++;
|
||||||
entry.hash = hash;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if (entry.hash == hash && entry.key.length == sLen && Arrays.equals(entry.key, 0, sLen, buf, 0, sLen)) {
|
if (entry.keyLen == sLen && eq(buf, entry.key, entry.keyLen)) {
|
||||||
entry.min = (short) Math.min(entry.min, val);
|
entry.min = (short) Math.min(entry.min, val);
|
||||||
entry.max = (short) Math.max(entry.max, val);
|
entry.max = (short) Math.max(entry.max, val);
|
||||||
entry.sum += val;
|
entry.sum += val;
|
||||||
@ -103,6 +103,23 @@ public class CalculateAverage_parkertimmins {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static boolean eq(byte[] buf, byte[] entryKey, int sLen) {
|
||||||
|
int needed = sLen;
|
||||||
|
for (int offset = 0; offset <= 96; offset += 32) {
|
||||||
|
var a = ByteVector.fromArray(ByteVector.SPECIES_256, buf, offset);
|
||||||
|
var b = ByteVector.fromArray(ByteVector.SPECIES_256, entryKey, offset);
|
||||||
|
int matches = a.eq(b).not().firstTrue();
|
||||||
|
if (needed <= 32) {
|
||||||
|
return matches >= needed;
|
||||||
|
}
|
||||||
|
else if (matches < 32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
needed -= 32;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static long findNextEntryStart(MemorySegment ms, long offset) {
|
static long findNextEntryStart(MemorySegment ms, long offset) {
|
||||||
long curr = offset;
|
long curr = offset;
|
||||||
while (ms.get(ValueLayout.JAVA_BYTE, curr) != '\n') {
|
while (ms.get(ValueLayout.JAVA_BYTE, curr) != '\n') {
|
||||||
@ -112,8 +129,17 @@ public class CalculateAverage_parkertimmins {
|
|||||||
return curr;
|
return curr;
|
||||||
}
|
}
|
||||||
|
|
||||||
static short[] digits10s = { 0, 100, 200, 300, 400, 500, 600, 700, 800, 900 };
|
static short[] digits2s = new short[256];
|
||||||
static short[] digits1s = { 0, 10, 20, 30, 40, 50, 60, 70, 80, 90 };
|
static short[] digits1s = new short[256];
|
||||||
|
static short[] digits0s = new short[256];
|
||||||
|
|
||||||
|
static {
|
||||||
|
for (int i = 0; i < 10; ++i) {
|
||||||
|
digits2s[i + ((int) '0')] = (short) (i * 100);
|
||||||
|
digits1s[i + ((int) '0')] = (short) (i * 10);
|
||||||
|
digits0s[i + ((int) '0')] = (short) i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void processRangeScalar(MemorySegment ms, long start, long end, final OpenHashTable localAgg) {
|
static void processRangeScalar(MemorySegment ms, long start, long end, final OpenHashTable localAgg) {
|
||||||
byte[] buf = new byte[128];
|
byte[] buf = new byte[128];
|
||||||
@ -139,9 +165,10 @@ public class CalculateAverage_parkertimmins {
|
|||||||
boolean neg = ms.get(ValueLayout.JAVA_BYTE, tempIdx) == '-';
|
boolean neg = ms.get(ValueLayout.JAVA_BYTE, tempIdx) == '-';
|
||||||
boolean twoDig = ms.get(ValueLayout.JAVA_BYTE, tempIdx + 1 + (neg ? 1 : 0)) == '.';
|
boolean twoDig = ms.get(ValueLayout.JAVA_BYTE, tempIdx + 1 + (neg ? 1 : 0)) == '.';
|
||||||
int len = 3 + (neg ? 1 : 0) + (twoDig ? 0 : 1);
|
int len = 3 + (neg ? 1 : 0) + (twoDig ? 0 : 1);
|
||||||
int d0 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 1)) - '0';
|
int d0 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 1));
|
||||||
int d1 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 3)) - '0';
|
int d1 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 3));
|
||||||
int base = d0 + digits1s[d1] + (twoDig ? 0 : digits10s[((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 4)) - '0']);
|
int d2 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 4)); // could be - or \n
|
||||||
|
int base = digits0s[d0] + digits1s[d1] + digits2s[d2];
|
||||||
short temp = (short) (neg ? -base : base);
|
short temp = (short) (neg ? -base : base);
|
||||||
|
|
||||||
localAgg.add(buf, sLen, temp, hash);
|
localAgg.add(buf, sLen, temp, hash);
|
||||||
@ -150,100 +177,55 @@ public class CalculateAverage_parkertimmins {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static int hash(byte[] buf, int sLen) {
|
static int hash(byte[] buf, int sLen) {
|
||||||
// TODO find a hash that works directly from byte array
|
int shift = Math.max(0, 8 - sLen) << 3;
|
||||||
// if shorter than 8 chars, mask out upper bits
|
long mask = (~0L) >>> shift;
|
||||||
long mask = sLen < 8 ? -(1L << ((8 - sLen) << 3)) : 0xFFFFFFFFL;
|
long val = ((buf[7] & 0xffL) << 56) | ((buf[6] & 0xffL) << 48) | ((buf[5] & 0xffL) << 40) | ((buf[4] & 0xffL) << 32) | ((buf[3] & 0xffL) << 24)
|
||||||
long val = ((buf[0] & 0xffL) << 56) | ((buf[1] & 0xffL) << 48) | ((buf[2] & 0xffL) << 40) | ((buf[3] & 0xffL) << 32) | ((buf[4] & 0xffL) << 24)
|
| ((buf[2] & 0xffL) << 16) | ((buf[1] & 0xFFL) << 8) | (buf[0] & 0xffL);
|
||||||
| ((buf[5] & 0xffL) << 16) | ((buf[6] & 0xFFL) << 8) | (buf[7] & 0xffL);
|
|
||||||
val &= mask;
|
val &= mask;
|
||||||
|
|
||||||
// also worth trying: https://lemire.me/blog/2015/10/22/faster-hashing-without-effort/
|
|
||||||
// lemire: https://lemire.me/blog/2023/07/14/recognizing-string-prefixes-with-simd-instructions/
|
// lemire: https://lemire.me/blog/2023/07/14/recognizing-string-prefixes-with-simd-instructions/
|
||||||
int hash = (int) (((((val >> 32) ^ val) & 0xffffffffL) * 3523216699L) >> 32);
|
int hash = (int) (((((val >> 32) ^ val) & 0xffffffffL) * 3523216699L) >> 32);
|
||||||
return hash;
|
return hash;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void processRangeSIMD(MemorySegment ms, boolean frontPad, boolean backPad, long start, long end, final OpenHashTable localAgg) {
|
static void processRangeSIMD(MemorySegment ms, boolean isFirst, boolean isLast, long start, long end, final OpenHashTable localAgg) {
|
||||||
byte[] buf = new byte[128];
|
byte[] buf = new byte[128];
|
||||||
|
|
||||||
long curr = frontPad ? findNextEntryStart(ms, start) : start;
|
long curr = isFirst ? start : findNextEntryStart(ms, start);
|
||||||
long limit = end - padding;
|
long limit = isLast ? end - padding : end;
|
||||||
|
|
||||||
var needle = ByteVector.broadcast(ByteVector.SPECIES_256, ';');
|
|
||||||
while (curr < limit) {
|
while (curr < limit) {
|
||||||
|
int nl = 0;
|
||||||
int segStart = 0;
|
for (int offset = 0; offset < 128; offset += 32) {
|
||||||
int sLen;
|
ByteVector section = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, ms, curr + offset, ByteOrder.LITTLE_ENDIAN);
|
||||||
|
section.intoArray(buf, offset);
|
||||||
while (true) {
|
var idx = section.eq((byte) '\n').firstTrue();
|
||||||
var section = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, ms, curr + segStart, ByteOrder.LITTLE_ENDIAN);
|
|
||||||
section.intoArray(buf, segStart);
|
|
||||||
VectorMask<Byte> matches = section.compare(VectorOperators.EQ, needle);
|
|
||||||
int idx = matches.firstTrue();
|
|
||||||
if (idx != 32) {
|
if (idx != 32) {
|
||||||
sLen = segStart + idx;
|
nl = offset + idx;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
segStart += 32;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int hash = hash(buf, sLen);
|
int nl1 = buf[nl - 1];
|
||||||
|
int nl3 = buf[nl - 3];
|
||||||
curr += sLen;
|
int nl4 = buf[nl - 4];
|
||||||
curr++; // semicolon
|
int nl5 = buf[nl - 5];
|
||||||
|
int base = (nl1 - '0') + 10 * (nl3 - '0') + digits2s[nl4];
|
||||||
long tempIdx = curr;
|
boolean neg = nl4 == '-' || (nl4 != ';' && nl5 == '-');
|
||||||
boolean neg = ms.get(ValueLayout.JAVA_BYTE, tempIdx) == '-';
|
|
||||||
boolean twoDig = ms.get(ValueLayout.JAVA_BYTE, tempIdx + 1 + (neg ? 1 : 0)) == '.';
|
|
||||||
int len = 3 + (neg ? 1 : 0) + (twoDig ? 0 : 1);
|
|
||||||
int d0 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 1)) - '0';
|
|
||||||
int d1 = ((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 3)) - '0';
|
|
||||||
int base = d0 + digits1s[d1] + (twoDig ? 0 : digits10s[((char) ms.get(ValueLayout.JAVA_BYTE, tempIdx + len - 4)) - '0']);
|
|
||||||
short temp = (short) (neg ? -base : base);
|
short temp = (short) (neg ? -base : base);
|
||||||
|
int tempLen = 4 + (neg ? 1 : 0) + (base >= 100 ? 1 : 0);
|
||||||
|
int semi = nl - tempLen;
|
||||||
|
|
||||||
localAgg.add(buf, sLen, temp, hash);
|
int hash = hash(buf, semi);
|
||||||
curr = tempIdx + len + 1;
|
localAgg.add(buf, semi, temp, hash);
|
||||||
|
curr += (nl + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// last batch is near end of file, process without SIMD to avoid out-of-bounds
|
// last batch is near end of file, process without SIMD to avoid out-of-bounds
|
||||||
if (!backPad) {
|
if (isLast) {
|
||||||
processRangeScalar(ms, curr, end, localAgg);
|
processRangeScalar(ms, curr, end, localAgg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* For debugging issues with hash function
|
|
||||||
*/
|
|
||||||
static void checkHashDistributionQuality(ArrayList<OpenHashTable> localAggs) {
|
|
||||||
HashSet<Integer> uniquesHashValues = new HashSet<Integer>();
|
|
||||||
HashSet<String> uniqueCities = new HashSet<String>();
|
|
||||||
HashMap<String, HashSet<Integer>> cityToHash = new HashMap<>();
|
|
||||||
|
|
||||||
for (var agg : localAggs) {
|
|
||||||
for (OpenHashTable.Entry entry : agg.entries) {
|
|
||||||
if (entry == null) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
uniquesHashValues.add(entry.hash);
|
|
||||||
String station = new String(entry.key, StandardCharsets.UTF_8); // for UTF-8 encoding
|
|
||||||
uniqueCities.add(station);
|
|
||||||
|
|
||||||
if (!cityToHash.containsKey(station)) {
|
|
||||||
cityToHash.put(station, new HashSet<>());
|
|
||||||
}
|
|
||||||
cityToHash.get(station).add(entry.hash);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (var pair : cityToHash.entrySet()) {
|
|
||||||
if (pair.getValue().size() > 1) {
|
|
||||||
System.err.println("multiple hashes: " + pair.getKey() + " " + pair.getValue());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
System.err.println("Unique stations: " + uniqueCities.size() + ", unique hash values: " + uniquesHashValues.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Combine thread local values
|
* Combine thread local values
|
||||||
*/
|
*/
|
||||||
@ -254,7 +236,7 @@ public class CalculateAverage_parkertimmins {
|
|||||||
if (entry == null) {
|
if (entry == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
String station = new String(entry.key, StandardCharsets.UTF_8); // for UTF-8 encoding
|
String station = new String(entry.key, 0, entry.keyLen, StandardCharsets.UTF_8); // for UTF-8 encoding
|
||||||
var currentVal = global.get(station);
|
var currentVal = global.get(station);
|
||||||
if (currentVal != null) {
|
if (currentVal != null) {
|
||||||
currentVal.merge(entry);
|
currentVal.merge(entry);
|
||||||
@ -267,8 +249,6 @@ public class CalculateAverage_parkertimmins {
|
|||||||
return global;
|
return global;
|
||||||
}
|
}
|
||||||
|
|
||||||
static final long batchSize = 10_000_000;
|
|
||||||
|
|
||||||
static final int padding = 200; // max entry size is 107ish == 100 (station) + 1 (semicolon) + 5 (temp, eg -99.9) + 1 (newline)
|
static final int padding = 200; // max entry size is 107ish == 100 (station) + 1 (semicolon) + 5 (temp, eg -99.9) + 1 (newline)
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException, InterruptedException {
|
public static void main(String[] args) throws IOException, InterruptedException {
|
||||||
@ -277,7 +257,10 @@ public class CalculateAverage_parkertimmins {
|
|||||||
|
|
||||||
int numThreads = Runtime.getRuntime().availableProcessors();
|
int numThreads = Runtime.getRuntime().availableProcessors();
|
||||||
|
|
||||||
|
final long batchSize = 10_000_000;
|
||||||
|
|
||||||
final long fileSize = channel.size();
|
final long fileSize = channel.size();
|
||||||
|
// final long batchSize = fileSize / numThreads + 1;
|
||||||
final MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global());
|
final MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global());
|
||||||
final ArrayList<OpenHashTable> localAggs = new ArrayList<>(numThreads);
|
final ArrayList<OpenHashTable> localAggs = new ArrayList<>(numThreads);
|
||||||
Thread[] threads = new Thread[numThreads];
|
Thread[] threads = new Thread[numThreads];
|
||||||
@ -299,11 +282,9 @@ public class CalculateAverage_parkertimmins {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
final long endBatch = Math.min(startBatch + batchSize, fileSize);
|
final long endBatch = Math.min(startBatch + batchSize, fileSize);
|
||||||
final boolean first = startBatch == 0;
|
final boolean isFirstBatch = startBatch == 0;
|
||||||
final boolean frontPad = !first;
|
final boolean isLastBatch = endBatch == fileSize;
|
||||||
final boolean last = endBatch == fileSize;
|
processRangeSIMD(ms, isFirstBatch, isLastBatch, startBatch, endBatch, localAgg);
|
||||||
final boolean backPad = !last;
|
|
||||||
processRangeSIMD(ms, frontPad, backPad, startBatch, endBatch, localAgg);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user