Process two consecutive lines at a time (#651)
Use a better hash function Don't return index from temperature parsing extra JVM args Co-authored-by: Ian Preston <ianopolous@protonmail.com>
This commit is contained in:
parent
ab2a9a6fe5
commit
0c5c22882b
@ -16,4 +16,6 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector"
|
JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector"
|
||||||
|
#-Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 -XX:-UseTransparentHugePages"
|
||||||
|
|
||||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ianopolousfast
|
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ianopolousfast
|
||||||
|
@ -19,7 +19,6 @@ import jdk.incubator.vector.ByteVector;
|
|||||||
import jdk.incubator.vector.VectorOperators;
|
import jdk.incubator.vector.VectorOperators;
|
||||||
import jdk.incubator.vector.VectorSpecies;
|
import jdk.incubator.vector.VectorSpecies;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.lang.foreign.Arena;
|
import java.lang.foreign.Arena;
|
||||||
import java.lang.foreign.MemorySegment;
|
import java.lang.foreign.MemorySegment;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
@ -39,10 +38,7 @@ import static java.lang.foreign.ValueLayout.*;
|
|||||||
* * read chunks in parallel
|
* * read chunks in parallel
|
||||||
* * minimise allocation
|
* * minimise allocation
|
||||||
* * no unsafe
|
* * no unsafe
|
||||||
*
|
* * process multiple lines in each thread for better ILP
|
||||||
* Timings on 4 core i7-7500U CPU @ 2.70GHz:
|
|
||||||
* average_baseline: 4m48s
|
|
||||||
* ianopolous: 13.8s
|
|
||||||
*/
|
*/
|
||||||
public class CalculateAverage_ianopolousfast {
|
public class CalculateAverage_ianopolousfast {
|
||||||
|
|
||||||
@ -91,11 +87,22 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static int hashToIndex(long hash, int len) {
|
private static final int GOLDEN_RATIO = 0x9E3779B9;
|
||||||
// From Thomas Wuerthinger's entry
|
private static final int HASH_LROTATE = 5;
|
||||||
int hashAsInt = (int) (hash ^ (hash >>> 28));
|
|
||||||
int finalHash = (hashAsInt ^ (hashAsInt >>> 15));
|
// hash from giovannicuccu
|
||||||
return (finalHash & (len - 1));
|
private static int hash(MemorySegment memorySegment, long start, int len) {
|
||||||
|
int x;
|
||||||
|
int y;
|
||||||
|
if (len >= Integer.BYTES) {
|
||||||
|
x = memorySegment.get(JAVA_INT_UNALIGNED, start);
|
||||||
|
y = memorySegment.get(JAVA_INT_UNALIGNED, start + len - Integer.BYTES);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
x = memorySegment.get(JAVA_BYTE, start);
|
||||||
|
y = memorySegment.get(JAVA_BYTE, start + len - Byte.BYTES);
|
||||||
|
}
|
||||||
|
return (Integer.rotateLeft(x * GOLDEN_RATIO, HASH_LROTATE) ^ y) * GOLDEN_RATIO;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Stat createStation(long start, long end, MemorySegment buffer) {
|
public static Stat createStation(long start, long end, MemorySegment buffer) {
|
||||||
@ -105,8 +112,9 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
return new Stat(stationBuffer);
|
return new Stat(stationBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, Stat[] stations) {
|
public static Stat dedupeStation(long start, long end, MemorySegment buffer, Stat[] stations) {
|
||||||
int index = hashToIndex(hash, MAX_STATIONS);
|
int hash = hash(buffer, start, (int) (end - start));
|
||||||
|
int index = hash & (MAX_STATIONS - 1);
|
||||||
Stat match = stations[index];
|
Stat match = stations[index];
|
||||||
while (match != null) {
|
while (match != null) {
|
||||||
if (matchingStationBytes(start, end, buffer, match))
|
if (matchingStationBytes(start, end, buffer, match))
|
||||||
@ -119,37 +127,11 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static long maskHighBytes(long d, int nbytes) {
|
|
||||||
return d & (-1L << ((8 - nbytes) * 8));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Stat parseStation(long lineStart, MemorySegment buffer, Stat[] stations) {
|
|
||||||
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
|
|
||||||
int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();
|
|
||||||
|
|
||||||
long first8 = buffer.get(LONG_LAYOUT, lineStart);
|
|
||||||
long second8 = 0;
|
|
||||||
if (keySize <= 8) {
|
|
||||||
first8 = maskHighBytes(first8, keySize & 0x07);
|
|
||||||
}
|
|
||||||
else if (keySize < 16) {
|
|
||||||
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
|
|
||||||
}
|
|
||||||
else if (keySize == BYTE_SPECIES.vectorByteSize()) {
|
|
||||||
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
|
|
||||||
keySize++;
|
|
||||||
}
|
|
||||||
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
|
|
||||||
}
|
|
||||||
long hash = first8 ^ second8; // todo include later bytes
|
|
||||||
return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static short getMinus(long d) {
|
public static short getMinus(long d) {
|
||||||
return ((d & 0xff00000000000000L) ^ 0x2d00000000000000L) != 0 ? 0 : (short) -1;
|
return ((d & 0xff00000000000000L) ^ 0x2d00000000000000L) != 0 ? 0 : (short) -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static long processTemperature(long lineSplit, int size, MemorySegment buffer, Stat station) {
|
public static void processTemperature(long lineSplit, int size, MemorySegment buffer, Stat station) {
|
||||||
long d = buffer.get(LONG_LAYOUT, lineSplit);
|
long d = buffer.get(LONG_LAYOUT, lineSplit);
|
||||||
// negative is either 0 or -1
|
// negative is either 0 or -1
|
||||||
short negative = getMinus(d);
|
short negative = getMinus(d);
|
||||||
@ -162,10 +144,9 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
100 * (((byte) (d >> 24)) - '0'));
|
100 * (((byte) (d >> 24)) - '0'));
|
||||||
temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merkitty
|
temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merkitty
|
||||||
station.add(temperature);
|
station.add(temperature);
|
||||||
return lineSplit + size + 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static long parseLine(long lineStart, MemorySegment buffer, Stat[] stations) {
|
private static int lineSize(long lineStart, MemorySegment buffer) {
|
||||||
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
|
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
|
||||||
int lineSize = line.compare(VectorOperators.EQ, '\n').firstTrue();
|
int lineSize = line.compare(VectorOperators.EQ, '\n').firstTrue();
|
||||||
int index = lineSize;
|
int index = lineSize;
|
||||||
@ -174,33 +155,19 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, '\n').firstTrue();
|
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, '\n').firstTrue();
|
||||||
lineSize += index;
|
lineSize += index;
|
||||||
}
|
}
|
||||||
int keySize = lineSize - 6 + ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize - 6,
|
return lineSize;
|
||||||
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue();
|
|
||||||
|
|
||||||
long first8 = buffer.get(LONG_LAYOUT, lineStart);
|
|
||||||
long second8 = 0;
|
|
||||||
if (keySize <= 8) {
|
|
||||||
first8 = maskHighBytes(first8, keySize & 0x07);
|
|
||||||
}
|
|
||||||
else if (keySize < 16) {
|
|
||||||
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
|
|
||||||
}
|
|
||||||
else if (keySize == BYTE_SPECIES.vectorByteSize()) {
|
|
||||||
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
|
|
||||||
keySize++;
|
|
||||||
}
|
|
||||||
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
|
|
||||||
}
|
|
||||||
long hash = first8 ^ second8; // todo include later bytes
|
|
||||||
Stat station = dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
|
|
||||||
return processTemperature(lineStart + keySize + 1, lineSize - keySize - 1, buffer, station);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Stat[] parseStats(long startByte, long endByte, MemorySegment buffer) {
|
private static int keySize(int lineSize, long lineStart, MemorySegment buffer) {
|
||||||
|
return lineSize - 6 + ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize - 6,
|
||||||
|
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Stat[] parseStats(long start1, long end2, MemorySegment buffer) {
|
||||||
// read first partial line
|
// read first partial line
|
||||||
if (startByte > 0) {
|
if (start1 > 0) {
|
||||||
for (int i = 0; i < MAX_LINE_LENGTH; i++) {
|
for (int i = 0; i < MAX_LINE_LENGTH; i++) {
|
||||||
byte b = buffer.get(JAVA_BYTE, startByte++);
|
byte b = buffer.get(JAVA_BYTE, start1++);
|
||||||
if (b == '\n') {
|
if (b == '\n') {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -213,38 +180,47 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
// this allows us to not worry about reading beyond the end
|
// this allows us to not worry about reading beyond the end
|
||||||
// in the inner loop (reducing branches)
|
// in the inner loop (reducing branches)
|
||||||
// We need at least the vector lane size bytes back
|
// We need at least the vector lane size bytes back
|
||||||
if (endByte == buffer.byteSize()) {
|
if (end2 == buffer.byteSize()) {
|
||||||
// reverse at least vector lane width
|
// reverse at least vector lane width
|
||||||
endByte = Math.max(buffer.byteSize() - BYTE_SPECIES.vectorByteSize(), 0);
|
end2 = Math.max(buffer.byteSize() - 2 * BYTE_SPECIES.vectorByteSize(), 0);
|
||||||
while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
|
while (end2 > 0 && buffer.get(JAVA_BYTE, end2) != '\n')
|
||||||
endByte--;
|
end2--;
|
||||||
|
|
||||||
if (endByte > 0)
|
if (end2 > 0)
|
||||||
endByte++;
|
end2++;
|
||||||
// copy into a larger buffer to avoid reading off end
|
// copy into a larger buffer to avoid reading off end
|
||||||
MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + BYTE_SPECIES.vectorByteSize());
|
MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 2 * BYTE_SPECIES.vectorByteSize());
|
||||||
for (long i = endByte; i < buffer.byteSize(); i++)
|
for (long i = end2; i < buffer.byteSize(); i++)
|
||||||
end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i));
|
end.set(JAVA_BYTE, i - end2, buffer.get(JAVA_BYTE, i));
|
||||||
int index = 0;
|
int index = 0;
|
||||||
while (endByte + index < buffer.byteSize()) {
|
while (end2 + index < buffer.byteSize()) {
|
||||||
Stat station = parseStation(index, end, stations);
|
int lineSize1 = lineSize(index, end);
|
||||||
int tempSize = 3;
|
int semiSearchStart = index + Math.max(0, lineSize1 - 6);
|
||||||
if (end.get(JAVA_BYTE, index + station.namelen + 5) == '\n')
|
int keySize1 = semiSearchStart - index + ByteVector.fromMemorySegment(BYTE_SPECIES, end, semiSearchStart,
|
||||||
tempSize = 4;
|
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue();
|
||||||
if (end.get(JAVA_BYTE, index + station.namelen + 6) == '\n')
|
Stat station1 = dedupeStation(index, index + keySize1, end, stations);
|
||||||
tempSize = 5;
|
processTemperature(index + keySize1 + 1, lineSize1 - keySize1 - 1, end, station1);
|
||||||
index = (int) processTemperature(index + station.namelen + 1, tempSize, end, station);
|
index += lineSize1 + 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
innerloop(startByte, endByte, buffer, stations);
|
while (start1 < end2) {
|
||||||
return stations;
|
int lineSize1 = lineSize(start1, buffer);
|
||||||
}
|
long start2 = start1 + lineSize1 + 1;
|
||||||
|
int lineSize2 = start2 < end2 ? lineSize(start2, buffer) : 0;
|
||||||
private static void innerloop(long startByte, long endByte, MemorySegment buffer, Stat[] stations) {
|
int keySize1 = keySize(lineSize1, start1, buffer);
|
||||||
while (startByte < endByte) {
|
int keySize2 = keySize(lineSize2, start2, buffer);
|
||||||
startByte = parseLine(startByte, buffer, stations);
|
Stat station1 = dedupeStation(start1, start1 + keySize1, buffer, stations);
|
||||||
|
processTemperature(start1 + keySize1 + 1, lineSize1 - keySize1 - 1, buffer, station1);
|
||||||
|
if (start2 < end2) {
|
||||||
|
Stat station2 = dedupeStation(start2, start2 + keySize2, buffer, stations);
|
||||||
|
processTemperature(start2 + keySize2 + 1, lineSize2 - keySize2 - 1, buffer, station2);
|
||||||
|
start1 = start2 + lineSize2 + 1;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
start1 += lineSize1 + 1;
|
||||||
}
|
}
|
||||||
|
return stations;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Stat {
|
public static class Stat {
|
||||||
|
Loading…
Reference in New Issue
Block a user