Introducing the vector api. 1s faster on 4 core i7 (#506)
Co-authored-by: Ian Preston <ianopolous@protonmail.com>
This commit is contained in:
parent
114ba76d20
commit
062f2bbecf
@ -15,5 +15,5 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
JAVA_OPTS="--enable-preview"
|
JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector"
|
||||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ianopolousfast
|
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ianopolousfast
|
||||||
|
@ -15,6 +15,10 @@
|
|||||||
*/
|
*/
|
||||||
package dev.morling.onebrc;
|
package dev.morling.onebrc;
|
||||||
|
|
||||||
|
import jdk.incubator.vector.ByteVector;
|
||||||
|
import jdk.incubator.vector.VectorOperators;
|
||||||
|
import jdk.incubator.vector.VectorSpecies;
|
||||||
|
|
||||||
import java.lang.foreign.Arena;
|
import java.lang.foreign.Arena;
|
||||||
import java.lang.foreign.MemorySegment;
|
import java.lang.foreign.MemorySegment;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
@ -30,19 +34,23 @@ import static java.lang.foreign.ValueLayout.*;
|
|||||||
/* A fast implementation with no unsafe.
|
/* A fast implementation with no unsafe.
|
||||||
* Features:
|
* Features:
|
||||||
* * memory mapped file using preview Arena FFI
|
* * memory mapped file using preview Arena FFI
|
||||||
|
* * semicolon finding using incubator vector api
|
||||||
* * read chunks in parallel
|
* * read chunks in parallel
|
||||||
* * minimise allocation
|
* * minimise allocation
|
||||||
* * no unsafe
|
* * no unsafe
|
||||||
*
|
*
|
||||||
* Timings on 4 core i7-7500U CPU @ 2.70GHz:
|
* Timings on 4 core i7-7500U CPU @ 2.70GHz:
|
||||||
* average_baseline: 4m48s
|
* average_baseline: 4m48s
|
||||||
* ianopolous: 16s
|
* ianopolous: 15s
|
||||||
*/
|
*/
|
||||||
public class CalculateAverage_ianopolousfast {
|
public class CalculateAverage_ianopolousfast {
|
||||||
|
|
||||||
public static final int MAX_LINE_LENGTH = 107;
|
public static final int MAX_LINE_LENGTH = 107;
|
||||||
public static final int MAX_STATIONS = 1 << 14;
|
public static final int MAX_STATIONS = 1 << 14;
|
||||||
private static final OfLong LONG_LAYOUT = JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
|
private static final OfLong LONG_LAYOUT = JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
|
||||||
|
private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED.length() >= 32
|
||||||
|
? ByteVector.SPECIES_256
|
||||||
|
: ByteVector.SPECIES_128;
|
||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
Arena arena = Arena.global();
|
Arena arena = Arena.global();
|
||||||
@ -165,58 +173,40 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static long hasSemicolon(long d) {
|
|
||||||
// from Hacker's Delight page 92
|
|
||||||
d = d ^ 0x3b3b3b3b3b3b3b3bL;
|
|
||||||
long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL;
|
|
||||||
return ~(y | d | 0x7f7f7f7f7f7f7f7fL);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static int getSemicolonIndex(long y) {
|
|
||||||
// from Hacker's Delight page 92
|
|
||||||
return Long.numberOfLeadingZeros(y) >> 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
static long maskHighBytes(long d, int nbytes) {
|
static long maskHighBytes(long d, int nbytes) {
|
||||||
return d & (-1L << ((8 - nbytes) * 8));
|
return d & (-1L << ((8 - nbytes) * 8));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Stat parseStation(long lineStart, MemorySegment buffer, List<List<Stat>> stations) {
|
public static Stat parseStation(long lineStart, MemorySegment buffer, List<List<Stat>> stations) {
|
||||||
// find semicolon and update hash as we go, reading a long at a time
|
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
|
||||||
long d = buffer.get(LONG_LAYOUT, lineStart);
|
int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();
|
||||||
long hasSemi = hasSemicolon(d);
|
|
||||||
if (hasSemi != 0) {
|
|
||||||
int semiIndex = getSemicolonIndex(hasSemi);
|
|
||||||
d = maskHighBytes(d, semiIndex);
|
|
||||||
return dedupeStation8(lineStart, lineStart + semiIndex, d, d, buffer, stations);
|
|
||||||
}
|
|
||||||
long first8 = d;
|
|
||||||
long hash = d;
|
|
||||||
|
|
||||||
d = buffer.get(LONG_LAYOUT, lineStart + 8);
|
if (keySize == BYTE_SPECIES.vectorByteSize()) {
|
||||||
hasSemi = hasSemicolon(d);
|
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
|
||||||
if (hasSemi != 0) {
|
keySize++;
|
||||||
int semiIndex = getSemicolonIndex(hasSemi);
|
}
|
||||||
if (semiIndex == 0)
|
long first8 = buffer.get(LONG_LAYOUT, lineStart);
|
||||||
return dedupeStation8(lineStart, lineStart + 8, first8, first8, buffer, stations);
|
if (keySize < 8)
|
||||||
d = maskHighBytes(d, semiIndex);
|
return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations);
|
||||||
return dedupeStation16(lineStart, lineStart + 8 + semiIndex, first8 ^ d, first8, d, buffer, stations);
|
long second8 = buffer.get(LONG_LAYOUT, lineStart + 8);
|
||||||
|
if (keySize < 16)
|
||||||
|
return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations);
|
||||||
|
long hash = first8 ^ second8; // todo include other bytes
|
||||||
|
return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations);
|
||||||
}
|
}
|
||||||
|
|
||||||
int index = 8;
|
long first8 = buffer.get(LONG_LAYOUT, lineStart);
|
||||||
long second8 = d;
|
if (keySize <= 8) {
|
||||||
while (hasSemi == 0) {
|
first8 = maskHighBytes(first8, keySize & 0x07);
|
||||||
hash = hash ^ d;
|
return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations);
|
||||||
index += 8;
|
|
||||||
d = buffer.get(LONG_LAYOUT, lineStart + index);
|
|
||||||
hasSemi = hasSemicolon(d);
|
|
||||||
}
|
}
|
||||||
int semiIndex = getSemicolonIndex(hasSemi);
|
long second8 = buffer.get(LONG_LAYOUT, lineStart + 8);
|
||||||
d = maskHighBytes(d, semiIndex);
|
if (keySize < 16) {
|
||||||
if (semiIndex > 0) {
|
second8 = maskHighBytes(second8, keySize & 0x07);
|
||||||
hash = hash ^ d;
|
return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations);
|
||||||
}
|
}
|
||||||
return dedupeStation(lineStart, lineStart + index + semiIndex, hash, first8, second8, buffer, stations);
|
long hash = first8 ^ second8; // todo include later bytes
|
||||||
|
return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int getDot(long d) {
|
public static int getDot(long d) {
|
||||||
@ -266,24 +256,30 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
for (int i = 0; i < MAX_STATIONS; i++)
|
for (int i = 0; i < MAX_STATIONS; i++)
|
||||||
stations.add(null);
|
stations.add(null);
|
||||||
|
|
||||||
// Handle reading the very last line in the file
|
// Handle reading the very last few lines in the file
|
||||||
// this allows us to not worry about reading a long beyond the end
|
// this allows us to not worry about reading beyond the end
|
||||||
// in the inner loop (reducing branches)
|
// in the inner loop (reducing branches)
|
||||||
// We only need to read one because the min record size is 6 bytes
|
// We need at least the vector lane size bytes back
|
||||||
// so 2nd last record must be > 8 from end
|
|
||||||
if (endByte == buffer.byteSize()) {
|
if (endByte == buffer.byteSize()) {
|
||||||
endByte -= 2; // skip final new line
|
endByte -= 1; // skip final new line
|
||||||
while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
|
// reverse at least vector lane width
|
||||||
|
while (endByte > 0 && buffer.byteSize() - endByte < BYTE_SPECIES.vectorByteSize()) {
|
||||||
endByte--;
|
endByte--;
|
||||||
|
while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
|
||||||
|
endByte--;
|
||||||
|
}
|
||||||
|
|
||||||
if (endByte > 0)
|
if (endByte > 0)
|
||||||
endByte++;
|
endByte++;
|
||||||
// copy into a 8n sized buffer to avoid reading off end
|
// copy into a larger buffer to avoid reading off end
|
||||||
MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 4);
|
MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + BYTE_SPECIES.vectorByteSize());
|
||||||
for (long i = endByte; i < buffer.byteSize(); i++)
|
for (long i = endByte; i < buffer.byteSize(); i++)
|
||||||
end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i));
|
end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i));
|
||||||
Stat station = parseStation(0, end, stations);
|
int index = 0;
|
||||||
processTemperature(station.name.length + 1, end, station);
|
while (endByte + index < buffer.byteSize()) {
|
||||||
|
Stat station = parseStation(index, end, stations);
|
||||||
|
index = (int) processTemperature(index + station.name.length + 1, end, station);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
while (startByte < endByte) {
|
while (startByte < endByte) {
|
||||||
|
Loading…
Reference in New Issue
Block a user