ddimtirov - switched to the foreign memory access preview API for another 10% speedup

This commit is contained in:
Dimitar Dimitrov 2024-01-04 04:22:39 +09:00 committed by Gunnar Morling
parent 1923fc65a8
commit d73457872f
2 changed files with 69 additions and 60 deletions

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
# #
# --enable-preview to use the new memory mapped segments
JAVA_OPTS="-XX:+UseParallelGC" # We don't allocate much, so just give it 1G heap and turn off GC; the AlwaysPreTouch was suggested by the ergonomics
JAVA_OPTS="--enable-preview -Xms1g -Xmx1g -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -XX:+AlwaysPreTouch"
time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ddimtirov time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ddimtirov

View File

@ -15,9 +15,10 @@
*/ */
package dev.morling.onebrc; package dev.morling.onebrc;
import java.io.*; import java.io.*;
import java.nio.MappedByteBuffer; import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.channels.FileChannel; import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Files; import java.nio.file.Files;
@ -47,15 +48,16 @@ public class CalculateAverage_ddimtirov {
var start = Instant.now(); var start = Instant.now();
var desiredSegmentsCount = Runtime.getRuntime().availableProcessors(); var desiredSegmentsCount = Runtime.getRuntime().availableProcessors();
var segments = FileSegment.forFile(path, desiredSegmentsCount); var fileSegments = FileSegment.forFile(path, desiredSegmentsCount);
var trackers = segments.stream().parallel().map(segment -> { var trackers = fileSegments.stream().parallel().map(fileSegment -> {
try (var fileChannel = (FileChannel) Files.newByteChannel(path, StandardOpenOption.READ)) { try (var fileChannel = (FileChannel) Files.newByteChannel(path, StandardOpenOption.READ)) {
var tracker = new Tracker(); var tracker = new Tracker();
var segmentBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segment.size()); var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, fileSegment.start(), fileSegment.size(), Arena.ofConfined());
tracker.processSegment(segmentBuffer, segment.end()); tracker.processSegment(memorySegment);
return tracker; return tracker;
} catch (IOException e) { }
catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
}).toList(); }).toList();
@ -63,27 +65,27 @@ public class CalculateAverage_ddimtirov {
var result = summarizeTrackers(trackers); var result = summarizeTrackers(trackers);
System.out.println(result); System.out.println(result);
//noinspection ConstantValue // noinspection ConstantValue
if (start!=null) System.err.println(Duration.between(start, Instant.now())); if (start != null)
assert Files.readAllLines(Path.of("expected_result.txt")).getFirst().equals(result); System.err.println(Duration.between(start, Instant.now()));
assert Files.readAllLines(Path.of("measurements_result.txt")).getFirst().equals(result);
} }
record FileSegment(long start, long size) {
record FileSegment(long start, long end) {
public long size() { return end() - start(); }
public static List<FileSegment> forFile(Path file, int desiredSegmentsCount) throws IOException { public static List<FileSegment> forFile(Path file, int desiredSegmentsCount) throws IOException {
try (var raf = new RandomAccessFile(file.toFile(), "r")) { try (var raf = new RandomAccessFile(file.toFile(), "r")) {
List<FileSegment> segments = new ArrayList<>(); var segments = new ArrayList<FileSegment>();
long fileSize = raf.length(); var fileSize = raf.length();
long segmentSize = fileSize / desiredSegmentsCount; var segmentSize = fileSize / desiredSegmentsCount;
for (int segmentIdx = 0; segmentIdx < desiredSegmentsCount; segmentIdx++) { for (int segmentIdx = 0; segmentIdx < desiredSegmentsCount; segmentIdx++) {
long segStart = segmentIdx * segmentSize; var segStart = segmentIdx * segmentSize;
long segEnd = (segmentIdx == desiredSegmentsCount - 1) ? fileSize : segStart + segmentSize; var segEnd = (segmentIdx == desiredSegmentsCount - 1) ? fileSize : segStart + segmentSize;
segStart = findSegmentBoundary(raf, segmentIdx, 0, segStart, segEnd); segStart = findSegmentBoundary(raf, segmentIdx, 0, segStart, segEnd);
segEnd = findSegmentBoundary(raf, segmentIdx, desiredSegmentsCount - 1, segEnd, fileSize); segEnd = findSegmentBoundary(raf, segmentIdx, desiredSegmentsCount - 1, segEnd, fileSize);
segments.add(new FileSegment(segStart, segEnd)); var segSize = segEnd - segStart;
segments.add(new FileSegment(segStart, segSize));
} }
return segments; return segments;
} }
@ -103,28 +105,33 @@ public class CalculateAverage_ddimtirov {
private static String summarizeTrackers(List<Tracker> trackers) { private static String summarizeTrackers(List<Tracker> trackers) {
var result = new TreeMap<String, String>(); var result = new TreeMap<String, String>();
for (int i = 0; i < HASH_NO_CLASH_MODULUS; i++) { for (var i = 0; i < HASH_NO_CLASH_MODULUS; i++) {
String name = null; String name = null;
int min = Integer.MAX_VALUE; var min = Integer.MAX_VALUE;
int max = Integer.MIN_VALUE; var max = Integer.MIN_VALUE;
long sum = 0; var sum = 0L;
long count = 0; var count = 0L;
for (Tracker tracker : trackers) { for (Tracker tracker : trackers) {
if (tracker.names[i]==null) continue; if (tracker.names[i] == null)
if (name==null) name = tracker.names[i]; continue;
if (name == null)
name = tracker.names[i];
var minn = tracker.minMaxCount[i*3]; var minn = tracker.minMaxCount[i * 3];
var maxx = tracker.minMaxCount[i*3+1]; var maxx = tracker.minMaxCount[i * 3 + 1];
if (minn<min) min = minn; if (minn < min)
if (maxx>max) max = maxx; min = minn;
count += tracker.minMaxCount[i*3+2]; if (maxx > max)
max = maxx;
count += tracker.minMaxCount[i * 3 + 2];
sum += tracker.sums[i]; sum += tracker.sums[i];
} }
if (name==null) continue; if (name == null)
continue;
var mean = Math.round((double) sum / count) / 10.0; var mean = Math.round((double) sum / count) / 10.0;
result.put(name, (min/10.0) + "/" + mean + "/" + (max/10.0)); result.put(name, (min / 10.0) + "/" + mean + "/" + (max / 10.0));
} }
return result.toString(); return result.toString();
} }
@ -133,51 +140,50 @@ public class CalculateAverage_ddimtirov {
private final int[] minMaxCount = new int[HASH_NO_CLASH_MODULUS * 3]; private final int[] minMaxCount = new int[HASH_NO_CLASH_MODULUS * 3];
private final long[] sums = new long[HASH_NO_CLASH_MODULUS]; private final long[] sums = new long[HASH_NO_CLASH_MODULUS];
private final String[] names = new String[HASH_NO_CLASH_MODULUS]; private final String[] names = new String[HASH_NO_CLASH_MODULUS];
private final byte[] nameThreadLocal = new byte[64];
private void processSegment(MappedByteBuffer segmentBuffer, long segmentEnd) { private void processSegment(MemorySegment memory) {
int startLine; int position = 0;
int limit = segmentBuffer.limit(); long limit = memory.byteSize();
while ((startLine = segmentBuffer.position()) < limit) { while (position < limit) {
int pos = startLine; int pos = position;
byte b; byte b;
int nameLength = 0, nameHash = 0; int nameLength = 0, nameHash = 0;
while (pos != segmentEnd && (b = segmentBuffer.get(pos++)) != ';') { while ((b = memory.get(ValueLayout.JAVA_BYTE, pos++)) != ';') {
nameHash = nameHash*31 + b; nameHash = nameHash * 31 + b;
nameLength++; nameLength++;
} }
int temperature = 0, sign = 1; int temperature = 0, sign = 1;
outer: outer: while ((b = memory.get(ValueLayout.JAVA_BYTE, pos++)) != '\n') {
while (pos != segmentEnd && (b = segmentBuffer.get(pos++)) != '\n') {
switch (b) { switch (b) {
case '\r' : case '\r':
pos++; pos++;
break outer; break outer;
case '.' : case '.':
break; break;
case '-' : case '-':
sign = -1; sign = -1;
break; break;
default : default:
var digit = b - '0'; var digit = b - '0';
assert digit >= 0 && digit <= 9; assert digit >= 0 && digit <= 9;
temperature = 10 * temperature + digit; temperature = 10 * temperature + digit;
} }
} }
processLine(nameHash, segmentBuffer, startLine, nameLength, temperature * sign); processLine(nameHash, memory, position, nameLength, temperature * sign);
segmentBuffer.position(pos); position = pos;
} }
} }
public void processLine(int nameHash, MappedByteBuffer buffer, int nameOffset, int nameLength, int temperature) { public void processLine(int nameHash, MemorySegment buffer, int nameOffset, int nameLength, int temperature) {
var i = Math.abs(nameHash) % HASH_NO_CLASH_MODULUS; var i = Math.abs(nameHash) % HASH_NO_CLASH_MODULUS;
if (names[i]==null) { if (names[i] == null) {
names[i] = parseName(buffer, nameOffset, nameLength); names[i] = parseName(buffer, nameOffset, nameLength);
} else { }
else {
assert parseName(buffer, nameOffset, nameLength).equals(names[i]) : parseName(buffer, nameOffset, nameLength) + "!=" + names[i]; assert parseName(buffer, nameOffset, nameLength).equals(names[i]) : parseName(buffer, nameOffset, nameLength) + "!=" + names[i];
} }
@ -186,15 +192,17 @@ public class CalculateAverage_ddimtirov {
int mmcIndex = i * 3; int mmcIndex = i * 3;
var min = minMaxCount[mmcIndex + OFFSET_MIN]; var min = minMaxCount[mmcIndex + OFFSET_MIN];
var max = minMaxCount[mmcIndex + OFFSET_MAX]; var max = minMaxCount[mmcIndex + OFFSET_MAX];
if (temperature < min) minMaxCount[mmcIndex + OFFSET_MIN] = temperature; if (temperature < min)
if (temperature > max) minMaxCount[mmcIndex + OFFSET_MAX] = temperature; minMaxCount[mmcIndex + OFFSET_MIN] = temperature;
if (temperature > max)
minMaxCount[mmcIndex + OFFSET_MAX] = temperature;
minMaxCount[mmcIndex + OFFSET_COUNT]++; minMaxCount[mmcIndex + OFFSET_COUNT]++;
} }
private String parseName(MappedByteBuffer buffer, int nameOffset, int nameLength) { private String parseName(MemorySegment memory, int nameOffset, int nameLength) {
buffer.get(nameOffset, nameThreadLocal, 0, nameLength); byte[] array = memory.asSlice(nameOffset, nameLength).toArray(ValueLayout.JAVA_BYTE);
return new String(nameThreadLocal, 0, nameLength, StandardCharsets.UTF_8); return new String(array, StandardCharsets.UTF_8);
} }
} }
} }