Down to 14s locally (#583)
Use flat array for stats. Use simd for line termination Co-authored-by: Ian Preston <ianopolous@protonmail.com>
This commit is contained in:
parent
d5cedd6a35
commit
0bd1675571
@ -41,7 +41,7 @@ import static java.lang.foreign.ValueLayout.*;
|
|||||||
*
|
*
|
||||||
* Timings on 4 core i7-7500U CPU @ 2.70GHz:
|
* Timings on 4 core i7-7500U CPU @ 2.70GHz:
|
||||||
* average_baseline: 4m48s
|
* average_baseline: 4m48s
|
||||||
* ianopolous: 15s
|
* ianopolous: 14s
|
||||||
*/
|
*/
|
||||||
public class CalculateAverage_ianopolousfast {
|
public class CalculateAverage_ianopolousfast {
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
MemorySegment mmap = channel.map(FileChannel.MapMode.READ_ONLY, 0, filesize, arena);
|
MemorySegment mmap = channel.map(FileChannel.MapMode.READ_ONLY, 0, filesize, arena);
|
||||||
int nChunks = filesize < 4 * 1024 * 1024 ? 1 : Runtime.getRuntime().availableProcessors();
|
int nChunks = filesize < 4 * 1024 * 1024 ? 1 : Runtime.getRuntime().availableProcessors();
|
||||||
long chunkSize = (filesize + nChunks - 1) / nChunks;
|
long chunkSize = (filesize + nChunks - 1) / nChunks;
|
||||||
List<List<List<Stat>>> allResults = IntStream.range(0, nChunks)
|
List<Stat[]> allResults = IntStream.range(0, nChunks)
|
||||||
.parallel()
|
.parallel()
|
||||||
.mapToObj(i -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize), mmap))
|
.mapToObj(i -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize), mmap))
|
||||||
.toList();
|
.toList();
|
||||||
@ -69,7 +69,7 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
.parallel()
|
.parallel()
|
||||||
.flatMap(f -> {
|
.flatMap(f -> {
|
||||||
try {
|
try {
|
||||||
return f.stream().filter(Objects::nonNull).flatMap(Collection::stream);
|
return Arrays.stream(f).filter(Objects::nonNull);
|
||||||
}
|
}
|
||||||
catch (Exception e) {
|
catch (Exception e) {
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
@ -104,24 +104,23 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
return new Stat(stationBuffer);
|
return new Stat(stationBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, List<List<Stat>> stations) {
|
public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, Stat[] stations) {
|
||||||
int index = hashToIndex(hash, MAX_STATIONS);
|
int index = hashToIndex(hash, MAX_STATIONS);
|
||||||
List<Stat> matches = stations.get(index);
|
Stat match = stations[index];
|
||||||
if (matches == null) {
|
if (match == null) {
|
||||||
List<Stat> value = new ArrayList<>();
|
|
||||||
Stat res = createStation(start, end, buffer);
|
Stat res = createStation(start, end, buffer);
|
||||||
value.add(res);
|
stations[index] = res;
|
||||||
stations.set(index, value);
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for (int i = 0; i < matches.size(); i++) {
|
while (match != null) {
|
||||||
Stat s = matches.get(i);
|
if (matchingStationBytes(start, end, buffer, match))
|
||||||
if (matchingStationBytes(start, end, buffer, s))
|
return match;
|
||||||
return s;
|
index = (index + 1) % stations.length;
|
||||||
|
match = stations[index];
|
||||||
}
|
}
|
||||||
Stat res = createStation(start, end, buffer);
|
Stat res = createStation(start, end, buffer);
|
||||||
matches.add(res);
|
stations[index] = res;
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -130,50 +129,38 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
return d & (-1L << ((8 - nbytes) * 8));
|
return d & (-1L << ((8 - nbytes) * 8));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Stat parseStation(long lineStart, MemorySegment buffer, List<List<Stat>> stations) {
|
public static Stat parseStation(long lineStart, MemorySegment buffer, Stat[] stations) {
|
||||||
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
|
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
|
||||||
int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();
|
int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();
|
||||||
|
|
||||||
long first8 = buffer.get(LONG_LAYOUT, lineStart);
|
long first8 = buffer.get(LONG_LAYOUT, lineStart);
|
||||||
if (keySize == BYTE_SPECIES.vectorByteSize()) {
|
long second8 = 0;
|
||||||
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
|
|
||||||
keySize++;
|
|
||||||
}
|
|
||||||
long second8 = buffer.get(LONG_LAYOUT, lineStart + 8);
|
|
||||||
long hash = first8 ^ second8; // todo include other bytes
|
|
||||||
return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (keySize <= 8) {
|
if (keySize <= 8) {
|
||||||
first8 = maskHighBytes(first8, keySize & 0x07);
|
first8 = maskHighBytes(first8, keySize & 0x07);
|
||||||
}
|
}
|
||||||
long second8 = keySize <= 8 ? 0 : maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
|
else if (keySize <= 16) {
|
||||||
|
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
|
||||||
|
}
|
||||||
|
else if (keySize == BYTE_SPECIES.vectorByteSize()) {
|
||||||
|
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
|
||||||
|
keySize++;
|
||||||
|
}
|
||||||
|
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
|
||||||
|
}
|
||||||
long hash = first8 ^ second8; // todo include later bytes
|
long hash = first8 ^ second8; // todo include later bytes
|
||||||
return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
|
return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int getDot(long d) {
|
|
||||||
// from Hacker's Delight page 92
|
|
||||||
d = d ^ 0x2e2e2e2e2e2e2e2eL;
|
|
||||||
long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL;
|
|
||||||
y = ~(y | d | 0x7f7f7f7f7f7f7f7fL);
|
|
||||||
return Long.numberOfLeadingZeros(y) >> 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static short getMinus(long d) {
|
public static short getMinus(long d) {
|
||||||
d = d & 0xff00000000000000L;
|
return ((d & 0xff00000000000000L) ^ 0x2d00000000000000L) != 0 ? 0 : (short) -1;
|
||||||
d = d ^ 0x2d2d2d2d2d2d2d2dL;
|
|
||||||
long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL;
|
|
||||||
y = ~(y | d | 0x7f7f7f7f7f7f7f7fL);
|
|
||||||
return (short) ((Long.numberOfLeadingZeros(y) >> 6) - 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static long processTemperature(long lineSplit, MemorySegment buffer, Stat station) {
|
public static long processTemperature(long lineSplit, int size, MemorySegment buffer, Stat station) {
|
||||||
long d = buffer.get(LONG_LAYOUT, lineSplit);
|
long d = buffer.get(LONG_LAYOUT, lineSplit);
|
||||||
// negative is either 0 or -1
|
// negative is either 0 or -1
|
||||||
short negative = getMinus(d);
|
short negative = getMinus(d);
|
||||||
d = d << (negative * -8);
|
d = d << (negative * -8);
|
||||||
int dotIndex = getDot(d);
|
int dotIndex = size - 2 + negative;
|
||||||
d = (d >> 8) | 0x30000000_00000000L; // add a leading 0 digit
|
d = (d >> 8) | 0x30000000_00000000L; // add a leading 0 digit
|
||||||
d = d >> 8 * (5 - dotIndex);
|
d = d >> 8 * (5 - dotIndex);
|
||||||
short temperature = (short) ((byte) d - '0' +
|
short temperature = (short) ((byte) d - '0' +
|
||||||
@ -181,10 +168,41 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
100 * (((byte) (d >> 24)) - '0'));
|
100 * (((byte) (d >> 24)) - '0'));
|
||||||
temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merkitty
|
temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merkitty
|
||||||
station.add(temperature);
|
station.add(temperature);
|
||||||
return lineSplit - negative + dotIndex + 3;
|
return lineSplit + size + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<List<Stat>> parseStats(long startByte, long endByte, MemorySegment buffer) {
|
private static long parseLine(long lineStart, MemorySegment buffer, Stat[] stations) {
|
||||||
|
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
|
||||||
|
int lineSize = line.compare(VectorOperators.EQ, '\n').firstTrue();
|
||||||
|
int index = lineSize;
|
||||||
|
while (index == BYTE_SPECIES.vectorByteSize()) {
|
||||||
|
index = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize,
|
||||||
|
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, '\n').firstTrue();
|
||||||
|
lineSize += index;
|
||||||
|
}
|
||||||
|
int keySize = lineSize - 6 + ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize - 6,
|
||||||
|
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue();
|
||||||
|
|
||||||
|
long first8 = buffer.get(LONG_LAYOUT, lineStart);
|
||||||
|
long second8 = 0;
|
||||||
|
if (keySize <= 8) {
|
||||||
|
first8 = maskHighBytes(first8, keySize & 0x07);
|
||||||
|
}
|
||||||
|
else if (keySize <= 16) {
|
||||||
|
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
|
||||||
|
}
|
||||||
|
else if (keySize == BYTE_SPECIES.vectorByteSize()) {
|
||||||
|
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
|
||||||
|
keySize++;
|
||||||
|
}
|
||||||
|
second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
|
||||||
|
}
|
||||||
|
long hash = first8 ^ second8; // todo include later bytes
|
||||||
|
Stat station = dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
|
||||||
|
return processTemperature(lineStart + keySize + 1, lineSize - keySize - 1, buffer, station);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Stat[] parseStats(long startByte, long endByte, MemorySegment buffer) {
|
||||||
// read first partial line
|
// read first partial line
|
||||||
if (startByte > 0) {
|
if (startByte > 0) {
|
||||||
for (int i = 0; i < MAX_LINE_LENGTH; i++) {
|
for (int i = 0; i < MAX_LINE_LENGTH; i++) {
|
||||||
@ -195,9 +213,7 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
List<List<Stat>> stations = new ArrayList<>(MAX_STATIONS);
|
Stat[] stations = new Stat[MAX_STATIONS];
|
||||||
for (int i = 0; i < MAX_STATIONS; i++)
|
|
||||||
stations.add(null);
|
|
||||||
|
|
||||||
// Handle reading the very last few lines in the file
|
// Handle reading the very last few lines in the file
|
||||||
// this allows us to not worry about reading beyond the end
|
// this allows us to not worry about reading beyond the end
|
||||||
@ -218,7 +234,12 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
int index = 0;
|
int index = 0;
|
||||||
while (endByte + index < buffer.byteSize()) {
|
while (endByte + index < buffer.byteSize()) {
|
||||||
Stat station = parseStation(index, end, stations);
|
Stat station = parseStation(index, end, stations);
|
||||||
index = (int) processTemperature(index + station.namelen + 1, end, station);
|
int tempSize = 3;
|
||||||
|
if (end.get(JAVA_BYTE, index + station.namelen + 5) == '\n')
|
||||||
|
tempSize = 4;
|
||||||
|
if (end.get(JAVA_BYTE, index + station.namelen + 6) == '\n')
|
||||||
|
tempSize = 5;
|
||||||
|
index = (int) processTemperature(index + station.namelen + 1, tempSize, end, station);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -226,10 +247,9 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
return stations;
|
return stations;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void innerloop(long startByte, long endByte, MemorySegment buffer, List<List<Stat>> stations) {
|
private static void innerloop(long startByte, long endByte, MemorySegment buffer, Stat[] stations) {
|
||||||
while (startByte < endByte) {
|
while (startByte < endByte) {
|
||||||
Stat station = parseStation(startByte, buffer, stations);
|
startByte = parseLine(startByte, buffer, stations);
|
||||||
startByte = processTemperature(startByte + station.namelen + 1, buffer, station);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -278,4 +298,4 @@ public class CalculateAverage_ianopolousfast {
|
|||||||
return round((double) min) + "/" + round(((double) total) / count) + "/" + round((double) max);
|
return round((double) min) + "/" + round(((double) total) / count) + "/" + round((double) max);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user