Use simd for name comparison (#568)
Co-authored-by: Ian Preston <ianopolous@protonmail.com>
This commit is contained in:
parent
f7febea2f6
commit
8bae1b8781
@ -34,7 +34,7 @@ import static java.lang.foreign.ValueLayout.*;
|
|||||||
/* A fast implementation with no unsafe.
|
/* A fast implementation with no unsafe.
|
||||||
* Features:
|
* Features:
|
||||||
* * memory mapped file using preview Arena FFI
|
* * memory mapped file using preview Arena FFI
|
||||||
* * semicolon finding using incubator vector api
|
* * semicolon finding and name comparison using incubator vector api
|
||||||
* * read chunks in parallel
|
* * read chunks in parallel
|
||||||
* * minimise allocation
|
* * minimise allocation
|
||||||
* * no unsafe
|
* * no unsafe
|
||||||
@ -80,12 +80,11 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
System.out.println(merged);
|
System.out.println(merged);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static boolean matchingStationBytes(long start, long end, int offset, MemorySegment buffer, Stat existing) {
|
public static boolean matchingStationBytes(long start, long end, MemorySegment buffer, Stat existing) {
|
||||||
int len = (int) (end - start);
|
for (int index = 0; index < end - start; index += BYTE_SPECIES.vectorByteSize()) {
|
||||||
if (len != existing.name.length)
|
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, start + index, ByteOrder.nativeOrder(), BYTE_SPECIES.indexInRange(start + index, end));
|
||||||
return false;
|
ByteVector found = ByteVector.fromArray(BYTE_SPECIES, existing.name, index);
|
||||||
for (int i = offset; i < len; i++) {
|
if (!found.eq(line).allTrue())
|
||||||
if (existing.name[i] != buffer.get(JAVA_BYTE, offset + start++))
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
@ -98,21 +97,19 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
return (finalHash & (len - 1));
|
return (finalHash & (len - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Stat parseStation(long start, long end, long first8, long second8,
|
public static Stat createStation(long start, long end, MemorySegment buffer) {
|
||||||
MemorySegment buffer) {
|
|
||||||
byte[] stationBuffer = new byte[(int) (end - start)];
|
byte[] stationBuffer = new byte[(int) (end - start)];
|
||||||
for (long off = start; off < end; off++)
|
for (long off = start; off < end; off++)
|
||||||
stationBuffer[(int) (off - start)] = buffer.get(JAVA_BYTE, off);
|
stationBuffer[(int) (off - start)] = buffer.get(JAVA_BYTE, off);
|
||||||
return new Stat(stationBuffer, first8, second8);
|
return new Stat(stationBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Stat dedupeStation(long start, long end, long hash, long first8, long second8,
|
public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, List<List<Stat>> stations) {
|
||||||
MemorySegment buffer, List<List<Stat>> stations) {
|
|
||||||
int index = hashToIndex(hash, MAX_STATIONS);
|
int index = hashToIndex(hash, MAX_STATIONS);
|
||||||
List<Stat> matches = stations.get(index);
|
List<Stat> matches = stations.get(index);
|
||||||
if (matches == null) {
|
if (matches == null) {
|
||||||
List<Stat> value = new ArrayList<>();
|
List<Stat> value = new ArrayList<>();
|
||||||
Stat res = parseStation(start, end, first8, second8, buffer);
|
Stat res = createStation(start, end, buffer);
|
||||||
value.add(res);
|
value.add(res);
|
||||||
stations.set(index, value);
|
stations.set(index, value);
|
||||||
return res;
|
return res;
|
||||||
@ -120,54 +117,10 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
else {
|
else {
|
||||||
for (int i = 0; i < matches.size(); i++) {
|
for (int i = 0; i < matches.size(); i++) {
|
||||||
Stat s = matches.get(i);
|
Stat s = matches.get(i);
|
||||||
if (first8 == s.first8 && second8 == s.second8 && matchingStationBytes(start, end, 16, buffer, s))
|
if (matchingStationBytes(start, end, buffer, s))
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
Stat res = parseStation(start, end, first8, second8, buffer);
|
Stat res = createStation(start, end, buffer);
|
||||||
matches.add(res);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Stat dedupeStation8(long start, long end, long hash, long first8, MemorySegment buffer, List<List<Stat>> stations) {
|
|
||||||
int index = hashToIndex(hash, MAX_STATIONS);
|
|
||||||
List<Stat> matches = stations.get(index);
|
|
||||||
if (matches == null) {
|
|
||||||
List<Stat> value = new ArrayList<>();
|
|
||||||
Stat station = parseStation(start, end, first8, 0, buffer);
|
|
||||||
value.add(station);
|
|
||||||
stations.set(index, value);
|
|
||||||
return station;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
for (int i = 0; i < matches.size(); i++) {
|
|
||||||
Stat s = matches.get(i);
|
|
||||||
if (first8 == s.first8 && s.name.length <= 8)
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
Stat station = parseStation(start, end, first8, 0, buffer);
|
|
||||||
matches.add(station);
|
|
||||||
return station;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Stat dedupeStation16(long start, long end, long hash, long first8, long second8, MemorySegment buffer, List<List<Stat>> stations) {
|
|
||||||
int index = hashToIndex(hash, MAX_STATIONS);
|
|
||||||
List<Stat> matches = stations.get(index);
|
|
||||||
if (matches == null) {
|
|
||||||
List<Stat> value = new ArrayList<>();
|
|
||||||
Stat res = parseStation(start, end, first8, second8, buffer);
|
|
||||||
value.add(res);
|
|
||||||
stations.set(index, value);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
for (int i = 0; i < matches.size(); i++) {
|
|
||||||
Stat s = matches.get(i);
|
|
||||||
if (first8 == s.first8 && second8 == s.second8 && s.name.length <= 16)
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
Stat res = parseStation(start, end, first8, second8, buffer);
|
|
||||||
matches.add(res);
|
matches.add(res);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@ -181,32 +134,22 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
|
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
|
||||||
int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();
|
int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();
|
||||||
|
|
||||||
|
long first8 = buffer.get(LONG_LAYOUT, lineStart);
|
||||||
if (keySize == BYTE_SPECIES.vectorByteSize()) {
|
if (keySize == BYTE_SPECIES.vectorByteSize()) {
|
||||||
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
|
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
|
||||||
keySize++;
|
keySize++;
|
||||||
}
|
}
|
||||||
long first8 = buffer.get(LONG_LAYOUT, lineStart);
|
|
||||||
if (keySize < 8)
|
|
||||||
return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations);
|
|
||||||
long second8 = buffer.get(LONG_LAYOUT, lineStart + 8);
|
long second8 = buffer.get(LONG_LAYOUT, lineStart + 8);
|
||||||
if (keySize < 16)
|
|
||||||
return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations);
|
|
||||||
long hash = first8 ^ second8; // todo include other bytes
|
long hash = first8 ^ second8; // todo include other bytes
|
||||||
return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations);
|
return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
|
||||||
}
|
}
|
||||||
|
|
||||||
long first8 = buffer.get(LONG_LAYOUT, lineStart);
|
|
||||||
if (keySize <= 8) {
|
if (keySize <= 8) {
|
||||||
first8 = maskHighBytes(first8, keySize & 0x07);
|
first8 = maskHighBytes(first8, keySize & 0x07);
|
||||||
return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations);
|
|
||||||
}
|
|
||||||
long second8 = buffer.get(LONG_LAYOUT, lineStart + 8);
|
|
||||||
if (keySize < 16) {
|
|
||||||
second8 = maskHighBytes(second8, keySize & 0x07);
|
|
||||||
return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations);
|
|
||||||
}
|
}
|
||||||
|
long second8 = keySize <= 8 ? 0 : maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
|
||||||
long hash = first8 ^ second8; // todo include later bytes
|
long hash = first8 ^ second8; // todo include later bytes
|
||||||
return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations);
|
return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int getDot(long d) {
|
public static int getDot(long d) {
|
||||||
@ -261,13 +204,10 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
// in the inner loop (reducing branches)
|
// in the inner loop (reducing branches)
|
||||||
// We need at least the vector lane size bytes back
|
// We need at least the vector lane size bytes back
|
||||||
if (endByte == buffer.byteSize()) {
|
if (endByte == buffer.byteSize()) {
|
||||||
endByte -= 1; // skip final new line
|
|
||||||
// reverse at least vector lane width
|
// reverse at least vector lane width
|
||||||
while (endByte > 0 && buffer.byteSize() - endByte < BYTE_SPECIES.vectorByteSize()) {
|
endByte = Math.max(buffer.byteSize() - BYTE_SPECIES.vectorByteSize(), 0);
|
||||||
endByte--;
|
|
||||||
while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
|
while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
|
||||||
endByte--;
|
endByte--;
|
||||||
}
|
|
||||||
|
|
||||||
if (endByte > 0)
|
if (endByte > 0)
|
||||||
endByte++;
|
endByte++;
|
||||||
@ -278,28 +218,33 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
int index = 0;
|
int index = 0;
|
||||||
while (endByte + index < buffer.byteSize()) {
|
while (endByte + index < buffer.byteSize()) {
|
||||||
Stat station = parseStation(index, end, stations);
|
Stat station = parseStation(index, end, stations);
|
||||||
index = (int) processTemperature(index + station.name.length + 1, end, station);
|
index = (int) processTemperature(index + station.namelen + 1, end, station);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
innerloop(startByte, endByte, buffer, stations);
|
||||||
|
return stations;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void innerloop(long startByte, long endByte, MemorySegment buffer, List<List<Stat>> stations) {
|
||||||
while (startByte < endByte) {
|
while (startByte < endByte) {
|
||||||
Stat station = parseStation(startByte, buffer, stations);
|
Stat station = parseStation(startByte, buffer, stations);
|
||||||
startByte = processTemperature(startByte + station.name.length + 1, buffer, station);
|
startByte = processTemperature(startByte + station.namelen + 1, buffer, station);
|
||||||
}
|
}
|
||||||
return stations;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Stat {
|
public static class Stat {
|
||||||
final byte[] name;
|
final byte[] name;
|
||||||
|
final int namelen;
|
||||||
int count = 0;
|
int count = 0;
|
||||||
short min = Short.MAX_VALUE, max = Short.MIN_VALUE;
|
short min = Short.MAX_VALUE, max = Short.MIN_VALUE;
|
||||||
long total = 0;
|
long total = 0;
|
||||||
final long first8, second8;
|
|
||||||
|
|
||||||
public Stat(byte[] name, long first8, long second8) {
|
public Stat(byte[] name) {
|
||||||
this.name = name;
|
int vecSize = BYTE_SPECIES.vectorByteSize();
|
||||||
this.first8 = first8;
|
int arrayLen = (name.length + vecSize - 1) / vecSize * vecSize;
|
||||||
this.second8 = second8;
|
this.name = Arrays.copyOfRange(name, 0, arrayLen);
|
||||||
|
this.namelen = name.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void add(short value) {
|
public void add(short value) {
|
||||||
@ -326,7 +271,7 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public String name() {
|
public String name() {
|
||||||
return new String(name);
|
return new String(Arrays.copyOfRange(name, 0, namelen));
|
||||||
}
|
}
|
||||||
|
|
||||||
public String toString() {
|
public String toString() {
|
||||||
|
Loading…
Reference in New Issue
Block a user