3s (16%) faster, still no unsafe (#478)

* use Arena and MemorySegment to map entire file at once
* reduced branches and instructions
This commit is contained in:
Dr Ian Preston 2024-01-19 16:14:45 +00:00 committed by GitHub
parent 9b28dd2aec
commit fefe326a14
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,45 +15,53 @@
*/ */
package dev.morling.onebrc; package dev.morling.onebrc;
import java.io.*; import java.lang.foreign.Arena;
import java.nio.*; import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import java.nio.channels.*; import java.nio.channels.*;
import java.util.concurrent.*; import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.stream.*; import java.util.stream.*;
import java.util.*; import java.util.*;
import static java.lang.foreign.ValueLayout.*;
/* A fast implementation with no unsafe. /* A fast implementation with no unsafe.
* Features: * Features:
* * memory mapped file * * memory mapped file using preview Arena FFI
* * read chunks in parallel * * read chunks in parallel
* * minimise allocation * * minimise allocation
* * no unsafe * * no unsafe
* *
* Timings on 4 core i7-7500U CPU @ 2.70GHz: * Timings on 4 core i7-7500U CPU @ 2.70GHz:
* average_baseline: 4m48s * average_baseline: 4m48s
* ianopolous: 19s * ianopolous: 16s
*/ */
public class CalculateAverage_ianopolousfast { public class CalculateAverage_ianopolousfast {
public static final int MAX_LINE_LENGTH = 107; public static final int MAX_LINE_LENGTH = 107;
public static final int MAX_STATIONS = 10_000; public static final int MAX_STATIONS = 1 << 14;
private static final OfLong LONG_LAYOUT = JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
File input = new File("./measurements.txt"); Arena arena = Arena.global();
long filesize = input.length(); Path input = Path.of("measurements.txt");
// keep chunk size between 256 MB and 1G (1 chunk for files < 256MB) FileChannel channel = (FileChannel) Files.newByteChannel(input, StandardOpenOption.READ);
long chunkSize = Math.min(Math.max((filesize + 31) / 32, 256 * 1024 * 1024), 1024 * 1024 * 1024L); long filesize = Files.size(input);
int nChunks = (int) ((filesize + chunkSize - 1) / chunkSize); MemorySegment mmap = channel.map(FileChannel.MapMode.READ_ONLY, 0, filesize, arena);
ExecutorService pool = Executors.newVirtualThreadPerTaskExecutor(); int nChunks = filesize < 4 * 1024 * 1024 ? 1 : Runtime.getRuntime().availableProcessors();
List<Future<List<List<Stat>>>> allResults = IntStream.range(0, nChunks) long chunkSize = (filesize + nChunks - 1) / nChunks;
.mapToObj(i -> pool.submit(() -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize)))) List<List<List<Stat>>> allResults = IntStream.range(0, nChunks)
.parallel()
.mapToObj(i -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize), mmap))
.toList(); .toList();
TreeMap<String, Stat> merged = allResults.stream() TreeMap<String, Stat> merged = allResults.stream()
.parallel() .parallel()
.flatMap(f -> { .flatMap(f -> {
try { try {
return f.get().stream().filter(Objects::nonNull).flatMap(Collection::stream); return f.stream().filter(Objects::nonNull).flatMap(Collection::stream);
} }
catch (Exception e) { catch (Exception e) {
e.printStackTrace(); e.printStackTrace();
@ -64,25 +72,39 @@ public class CalculateAverage_ianopolousfast {
System.out.println(merged); System.out.println(merged);
} }
public static boolean matchingStationBytes(int start, int end, ByteBuffer buffer, Stat existing) { public static boolean matchingStationBytes(long start, long end, int offset, MemorySegment buffer, Stat existing) {
if (end - start != existing.name.length) int len = (int) (end - start);
if (len != existing.name.length)
return false; return false;
for (int i = start; i < end; i++) { for (int i = offset; i < len; i++) {
if (existing.name[i - start] != buffer.get(i)) if (existing.name[i] != buffer.get(JAVA_BYTE, offset + start++))
return false; return false;
} }
return true; return true;
} }
public static Stat dedupeStation(int start, int end, long hash, ByteBuffer buffer, List<List<Stat>> stations) { private static int hashToIndex(long hash, int len) {
int index = Math.floorMod(hash ^ (hash >> 32), MAX_STATIONS); // From Thomas Wuerthinger's entry
int hashAsInt = (int) (hash ^ (hash >>> 28));
int finalHash = (hashAsInt ^ (hashAsInt >>> 15));
return (finalHash & (len - 1));
}
public static Stat parseStation(long start, long end, long first8, long second8,
MemorySegment buffer) {
byte[] stationBuffer = new byte[(int) (end - start)];
for (long off = start; off < end; off++)
stationBuffer[(int) (off - start)] = buffer.get(JAVA_BYTE, off);
return new Stat(stationBuffer, first8, second8);
}
public static Stat dedupeStation(long start, long end, long hash, long first8, long second8,
MemorySegment buffer, List<List<Stat>> stations) {
int index = hashToIndex(hash, MAX_STATIONS);
List<Stat> matches = stations.get(index); List<Stat> matches = stations.get(index);
if (matches == null) { if (matches == null) {
List<Stat> value = new ArrayList<>(); List<Stat> value = new ArrayList<>();
byte[] stationBuffer = new byte[end - start]; Stat res = parseStation(start, end, first8, second8, buffer);
buffer.position(start);
buffer.get(stationBuffer);
Stat res = new Stat(stationBuffer);
value.add(res); value.add(res);
stations.set(index, value); stations.set(index, value);
return res; return res;
@ -90,92 +112,151 @@ public class CalculateAverage_ianopolousfast {
else { else {
for (int i = 0; i < matches.size(); i++) { for (int i = 0; i < matches.size(); i++) {
Stat s = matches.get(i); Stat s = matches.get(i);
if (matchingStationBytes(start, end, buffer, s)) if (first8 == s.first8 && second8 == s.second8 && matchingStationBytes(start, end, 16, buffer, s))
return s; return s;
} }
byte[] stationBuffer = new byte[end - start]; Stat res = parseStation(start, end, first8, second8, buffer);
buffer.position(start);
buffer.get(stationBuffer);
Stat res = new Stat(stationBuffer);
matches.add(res); matches.add(res);
return res; return res;
} }
} }
public static int getSemicolon(long d) { public static Stat dedupeStation8(long start, long end, long hash, long first8, MemorySegment buffer, List<List<Stat>> stations) {
int index = hashToIndex(hash, MAX_STATIONS);
List<Stat> matches = stations.get(index);
if (matches == null) {
List<Stat> value = new ArrayList<>();
Stat station = parseStation(start, end, first8, 0, buffer);
value.add(station);
stations.set(index, value);
return station;
}
else {
for (int i = 0; i < matches.size(); i++) {
Stat s = matches.get(i);
if (first8 == s.first8 && s.name.length <= 8)
return s;
}
Stat station = parseStation(start, end, first8, 0, buffer);
matches.add(station);
return station;
}
}
public static Stat dedupeStation16(long start, long end, long hash, long first8, long second8, MemorySegment buffer, List<List<Stat>> stations) {
int index = hashToIndex(hash, MAX_STATIONS);
List<Stat> matches = stations.get(index);
if (matches == null) {
List<Stat> value = new ArrayList<>();
Stat res = parseStation(start, end, first8, second8, buffer);
value.add(res);
stations.set(index, value);
return res;
}
else {
for (int i = 0; i < matches.size(); i++) {
Stat s = matches.get(i);
if (first8 == s.first8 && second8 == s.second8 && s.name.length <= 16)
return s;
}
Stat res = parseStation(start, end, first8, second8, buffer);
matches.add(res);
return res;
}
}
public static long hasSemicolon(long d) {
// from Hacker's Delight page 92 // from Hacker's Delight page 92
d = d ^ 0x3b3b3b3b3b3b3b3bL; d = d ^ 0x3b3b3b3b3b3b3b3bL;
long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL; long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL;
return ~(y | d | 0x7f7f7f7f7f7f7f7fL);
}
public static int getSemicolonIndex(long y) {
// from Hacker's Delight page 92
return Long.numberOfLeadingZeros(y) >> 3;
}
static long maskHighBytes(long d, int nbytes) {
return d & (-1L << ((8 - nbytes) * 8));
}
public static Stat parseStation(long lineStart, MemorySegment buffer, List<List<Stat>> stations) {
// find semicolon and update hash as we go, reading a long at a time
long d = buffer.get(LONG_LAYOUT, lineStart);
long hasSemi = hasSemicolon(d);
if (hasSemi != 0) {
int semiIndex = getSemicolonIndex(hasSemi);
d = maskHighBytes(d, semiIndex);
return dedupeStation8(lineStart, lineStart + semiIndex, d, d, buffer, stations);
}
long first8 = d;
long hash = d;
d = buffer.get(LONG_LAYOUT, lineStart + 8);
hasSemi = hasSemicolon(d);
if (hasSemi != 0) {
int semiIndex = getSemicolonIndex(hasSemi);
if (semiIndex == 0)
return dedupeStation8(lineStart, lineStart + 8, first8, first8, buffer, stations);
d = maskHighBytes(d, semiIndex);
return dedupeStation16(lineStart, lineStart + 8 + semiIndex, first8 ^ d, first8, d, buffer, stations);
}
int index = 8;
long second8 = d;
while (hasSemi == 0) {
hash = hash ^ d;
index += 8;
d = buffer.get(LONG_LAYOUT, lineStart + index);
hasSemi = hasSemicolon(d);
}
int semiIndex = getSemicolonIndex(hasSemi);
d = maskHighBytes(d, semiIndex);
if (semiIndex > 0) {
hash = hash ^ d;
}
return dedupeStation(lineStart, lineStart + index + semiIndex, hash, first8, second8, buffer, stations);
}
public static int getDot(long d) {
// from Hacker's Delight page 92
d = d ^ 0x2e2e2e2e2e2e2e2eL;
long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL;
y = ~(y | d | 0x7f7f7f7f7f7f7f7fL); y = ~(y | d | 0x7f7f7f7f7f7f7f7fL);
return Long.numberOfLeadingZeros(y) >> 3; return Long.numberOfLeadingZeros(y) >> 3;
} }
public static long updateHash(long hash, long x) { public static short getMinus(long d) {
return ((hash << 5) ^ x) * 0x517cc1b727220a95L; // fxHash d = d & 0xff00000000000000L;
d = d ^ 0x2d2d2d2d2d2d2d2dL;
long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL;
y = ~(y | d | 0x7f7f7f7f7f7f7f7fL);
return (short) ((Long.numberOfLeadingZeros(y) >> 6) - 1);
} }
public static Stat parseStation(int lineStart, ByteBuffer buffer, List<List<Stat>> stations) { public static long processTemperature(long lineSplit, MemorySegment buffer, Stat station) {
// find semicolon and update hash as we go, reading a long at a time long d = buffer.get(LONG_LAYOUT, lineSplit);
long d = buffer.getLong(lineStart); // negative is either 0 or -1
short negative = getMinus(d);
int semiIndex = getSemicolon(d); d = d << (negative * -8);
int index = 0; int dotIndex = getDot(d);
long hash = 0; d = (d >> 8) | 0x30000000_00000000L; // add a leading 0 digit
while (semiIndex == 8) { d = d >> 8 * (5 - dotIndex);
hash = updateHash(hash, d); short temperature = (short) ((byte) d - '0' +
index += 8; 10 * (((byte) (d >> 16)) - '0') +
d = buffer.getLong(lineStart + index); 100 * (((byte) (d >> 24)) - '0'));
semiIndex = getSemicolon(d); temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merkitty
}
// mask extra bytes off last long
d = d & (-1L << ((8 - semiIndex) * 8));
if (semiIndex > 0) {
hash = updateHash(hash, d);
}
return dedupeStation(lineStart, lineStart + index + semiIndex, hash, buffer, stations);
}
public static int processTemperature(int lineSplit, MappedByteBuffer buffer, Stat station) {
short temperature;
boolean negative = false;
byte b = buffer.get(lineSplit++);
if (b == '-') {
negative = true;
b = buffer.get(lineSplit++);
}
temperature = (short) (b - 0x30);
b = buffer.get(lineSplit++);
if (b == '.') {
b = buffer.get(lineSplit++);
temperature = (short) (temperature * 10 + (b - 0x30));
}
else {
temperature = (short) (temperature * 10 + (b - 0x30));
lineSplit++;
b = buffer.get(lineSplit++);
temperature = (short) (temperature * 10 + (b - 0x30));
}
temperature = negative ? (short) -temperature : temperature;
station.add(temperature); station.add(temperature);
return lineSplit + 1; return lineSplit - negative + dotIndex + 3;
} }
public static List<List<Stat>> parseStats(long startByte, long endByte) { public static List<List<Stat>> parseStats(long startByte, long endByte, MemorySegment buffer) {
try {
RandomAccessFile file = new RandomAccessFile("./measurements.txt", "r");
long maxEnd = Math.min(file.length(), endByte + MAX_LINE_LENGTH);
long len = maxEnd - startByte;
if (len > Integer.MAX_VALUE)
throw new RuntimeException("Segment size must fit into an int");
int maxDone = (int) (endByte - startByte);
MappedByteBuffer buffer = file.getChannel().map(FileChannel.MapMode.READ_ONLY, startByte, len);
int done = 0;
// read first partial line // read first partial line
if (startByte > 0) { if (startByte > 0) {
for (int i = 0; i < MAX_LINE_LENGTH; i++) { for (int i = 0; i < MAX_LINE_LENGTH; i++) {
byte b = buffer.get(i); byte b = buffer.get(JAVA_BYTE, startByte++);
if (b == '\n') { if (b == '\n') {
done = i + 1;
break; break;
} }
} }
@ -190,46 +271,39 @@ public class CalculateAverage_ianopolousfast {
// in the inner loop (reducing branches) // in the inner loop (reducing branches)
// We only need to read one because the min record size is 6 bytes // We only need to read one because the min record size is 6 bytes
// so 2nd last record must be > 8 from end // so 2nd last record must be > 8 from end
if (endByte == file.length()) { if (endByte == buffer.byteSize()) {
int offset = (int) (file.length() - startByte - 1); endByte -= 2; // skip final new line
while (buffer.get(offset) != '\n') // final new line while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
offset--; endByte--;
offset--;
while (offset > 0 && buffer.get(offset) != '\n') // end of second last line if (endByte > 0)
offset--; endByte++;
maxDone = offset;
if (offset > 0)
offset++;
// copy into a 8n sized buffer to avoid reading off end // copy into a 8n sized buffer to avoid reading off end
int roundedSize = (int) (file.length() - startByte) - offset; MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 4);
roundedSize = (roundedSize + 7) / 8 * 8; for (long i = endByte; i < buffer.byteSize(); i++)
byte[] end = new byte[roundedSize]; end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i));
for (int i = offset; i < (int) (file.length() - startByte); i++) Stat station = parseStation(0, end, stations);
end[i - offset] = buffer.get(i); processTemperature(station.name.length + 1, end, station);
Stat station = parseStation(0, ByteBuffer.wrap(end), stations);
processTemperature(offset + station.name.length + 1, buffer, station);
} }
int lineStart = done; while (startByte < endByte) {
while (lineStart < maxDone) { Stat station = parseStation(startByte, buffer, stations);
Stat station = parseStation(lineStart, buffer, stations); startByte = processTemperature(startByte + station.name.length + 1, buffer, station);
lineStart = processTemperature(lineStart + station.name.length + 1, buffer, station);
} }
return stations; return stations;
} }
catch (IOException e) {
throw new RuntimeException(e);
}
}
public static class Stat { public static class Stat {
final byte[] name; final byte[] name;
int count = 0; int count = 0;
short min = Short.MAX_VALUE, max = Short.MIN_VALUE; short min = Short.MAX_VALUE, max = Short.MIN_VALUE;
long total = 0; long total = 0;
final long first8, second8;
public Stat(byte[] name) { public Stat(byte[] name, long first8, long second8) {
this.name = name; this.name = name;
this.first8 = first8;
this.second8 = second8;
} }
public void add(short value) { public void add(short value) {