From d73457872f0b9990ab6258d0a96d3edad3b24b1a Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Thu, 4 Jan 2024 04:22:39 +0900 Subject: [PATCH] ddimtirov - switched to the foreign memory access preview API for another 10% speedup --- calculate_average_ddimtirov.sh | 5 +- .../onebrc/CalculateAverage_ddimtirov.java | 124 ++++++++++-------- 2 files changed, 69 insertions(+), 60 deletions(-) diff --git a/calculate_average_ddimtirov.sh b/calculate_average_ddimtirov.sh index 26fb652..94a1981 100644 --- a/calculate_average_ddimtirov.sh +++ b/calculate_average_ddimtirov.sh @@ -15,6 +15,7 @@ # limitations under the License. # - -JAVA_OPTS="-XX:+UseParallelGC" +# --enable-preview to use the new memory mapped segments +# We don't allocate much, so just give it 1G heap and turn off GC; the AlwaysPreTouch was suggested by the ergonomics +JAVA_OPTS="--enable-preview -Xms1g -Xmx1g -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -XX:+AlwaysPreTouch" time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ddimtirov diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java index 5a9261b..7f5ac50 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java @@ -15,9 +15,10 @@ */ package dev.morling.onebrc; - import java.io.*; -import java.nio.MappedByteBuffer; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Files; @@ -47,15 +48,16 @@ public class CalculateAverage_ddimtirov { var start = Instant.now(); var desiredSegmentsCount = Runtime.getRuntime().availableProcessors(); - var segments = FileSegment.forFile(path, desiredSegmentsCount); + var fileSegments = FileSegment.forFile(path, desiredSegmentsCount); - var trackers = segments.stream().parallel().map(segment -> { + var trackers = fileSegments.stream().parallel().map(fileSegment -> { try (var fileChannel = (FileChannel) Files.newByteChannel(path, StandardOpenOption.READ)) { var tracker = new Tracker(); - var segmentBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segment.size()); - tracker.processSegment(segmentBuffer, segment.end()); + var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, fileSegment.start(), fileSegment.size(), Arena.ofConfined()); + tracker.processSegment(memorySegment); return tracker; - } catch (IOException e) { + } + catch (IOException e) { throw new RuntimeException(e); } }).toList(); @@ -63,27 +65,27 @@ public class CalculateAverage_ddimtirov { var result = summarizeTrackers(trackers); System.out.println(result); - //noinspection ConstantValue - if (start!=null) System.err.println(Duration.between(start, Instant.now())); - assert Files.readAllLines(Path.of("expected_result.txt")).getFirst().equals(result); + // noinspection ConstantValue + if (start != null) + System.err.println(Duration.between(start, Instant.now())); + assert Files.readAllLines(Path.of("measurements_result.txt")).getFirst().equals(result); } - - record FileSegment(long start, long end) { - public long size() { return end() - start(); } - + record FileSegment(long start, long size) { public static List forFile(Path file, int desiredSegmentsCount) throws IOException { try (var raf = new RandomAccessFile(file.toFile(), "r")) { - List segments = new ArrayList<>(); - long fileSize = raf.length(); - long segmentSize = fileSize / desiredSegmentsCount; + var segments = new ArrayList(); + var fileSize = raf.length(); + var segmentSize = fileSize / desiredSegmentsCount; for (int segmentIdx = 0; segmentIdx < desiredSegmentsCount; segmentIdx++) { - long segStart = segmentIdx * segmentSize; - long segEnd = (segmentIdx == desiredSegmentsCount - 1) ? fileSize : segStart + segmentSize; + var segStart = segmentIdx * segmentSize; + var segEnd = (segmentIdx == desiredSegmentsCount - 1) ? fileSize : segStart + segmentSize; segStart = findSegmentBoundary(raf, segmentIdx, 0, segStart, segEnd); segEnd = findSegmentBoundary(raf, segmentIdx, desiredSegmentsCount - 1, segEnd, fileSize); - segments.add(new FileSegment(segStart, segEnd)); + var segSize = segEnd - segStart; + + segments.add(new FileSegment(segStart, segSize)); } return segments; } @@ -103,28 +105,33 @@ public class CalculateAverage_ddimtirov { private static String summarizeTrackers(List trackers) { var result = new TreeMap(); - for (int i = 0; i < HASH_NO_CLASH_MODULUS; i++) { + for (var i = 0; i < HASH_NO_CLASH_MODULUS; i++) { String name = null; - int min = Integer.MAX_VALUE; - int max = Integer.MIN_VALUE; - long sum = 0; - long count = 0; + var min = Integer.MAX_VALUE; + var max = Integer.MIN_VALUE; + var sum = 0L; + var count = 0L; for (Tracker tracker : trackers) { - if (tracker.names[i]==null) continue; - if (name==null) name = tracker.names[i]; + if (tracker.names[i] == null) + continue; + if (name == null) + name = tracker.names[i]; - var minn = tracker.minMaxCount[i*3]; - var maxx = tracker.minMaxCount[i*3+1]; - if (minnmax) max = maxx; - count += tracker.minMaxCount[i*3+2]; + var minn = tracker.minMaxCount[i * 3]; + var maxx = tracker.minMaxCount[i * 3 + 1]; + if (minn < min) + min = minn; + if (maxx > max) + max = maxx; + count += tracker.minMaxCount[i * 3 + 2]; sum += tracker.sums[i]; } - if (name==null) continue; + if (name == null) + continue; var mean = Math.round((double) sum / count) / 10.0; - result.put(name, (min/10.0) + "/" + mean + "/" + (max/10.0)); + result.put(name, (min / 10.0) + "/" + mean + "/" + (max / 10.0)); } return result.toString(); } @@ -133,51 +140,50 @@ public class CalculateAverage_ddimtirov { private final int[] minMaxCount = new int[HASH_NO_CLASH_MODULUS * 3]; private final long[] sums = new long[HASH_NO_CLASH_MODULUS]; private final String[] names = new String[HASH_NO_CLASH_MODULUS]; - private final byte[] nameThreadLocal = new byte[64]; - private void processSegment(MappedByteBuffer segmentBuffer, long segmentEnd) { - int startLine; - int limit = segmentBuffer.limit(); - while ((startLine = segmentBuffer.position()) < limit) { - int pos = startLine; + private void processSegment(MemorySegment memory) { + int position = 0; + long limit = memory.byteSize(); + while (position < limit) { + int pos = position; byte b; int nameLength = 0, nameHash = 0; - while (pos != segmentEnd && (b = segmentBuffer.get(pos++)) != ';') { - nameHash = nameHash*31 + b; + while ((b = memory.get(ValueLayout.JAVA_BYTE, pos++)) != ';') { + nameHash = nameHash * 31 + b; nameLength++; } int temperature = 0, sign = 1; - outer: - while (pos != segmentEnd && (b = segmentBuffer.get(pos++)) != '\n') { + outer: while ((b = memory.get(ValueLayout.JAVA_BYTE, pos++)) != '\n') { switch (b) { - case '\r' : + case '\r': pos++; break outer; - case '.' : + case '.': break; - case '-' : + case '-': sign = -1; break; - default : + default: var digit = b - '0'; assert digit >= 0 && digit <= 9; temperature = 10 * temperature + digit; } } - processLine(nameHash, segmentBuffer, startLine, nameLength, temperature * sign); - segmentBuffer.position(pos); + processLine(nameHash, memory, position, nameLength, temperature * sign); + position = pos; } } - public void processLine(int nameHash, MappedByteBuffer buffer, int nameOffset, int nameLength, int temperature) { + public void processLine(int nameHash, MemorySegment buffer, int nameOffset, int nameLength, int temperature) { var i = Math.abs(nameHash) % HASH_NO_CLASH_MODULUS; - if (names[i]==null) { + if (names[i] == null) { names[i] = parseName(buffer, nameOffset, nameLength); - } else { + } + else { assert parseName(buffer, nameOffset, nameLength).equals(names[i]) : parseName(buffer, nameOffset, nameLength) + "!=" + names[i]; } @@ -186,15 +192,17 @@ public class CalculateAverage_ddimtirov { int mmcIndex = i * 3; var min = minMaxCount[mmcIndex + OFFSET_MIN]; var max = minMaxCount[mmcIndex + OFFSET_MAX]; - if (temperature < min) minMaxCount[mmcIndex + OFFSET_MIN] = temperature; - if (temperature > max) minMaxCount[mmcIndex + OFFSET_MAX] = temperature; + if (temperature < min) + minMaxCount[mmcIndex + OFFSET_MIN] = temperature; + if (temperature > max) + minMaxCount[mmcIndex + OFFSET_MAX] = temperature; minMaxCount[mmcIndex + OFFSET_COUNT]++; } - private String parseName(MappedByteBuffer buffer, int nameOffset, int nameLength) { - buffer.get(nameOffset, nameThreadLocal, 0, nameLength); - return new String(nameThreadLocal, 0, nameLength, StandardCharsets.UTF_8); + private String parseName(MemorySegment memory, int nameOffset, int nameLength) { + byte[] array = memory.asSlice(nameOffset, nameLength).toArray(ValueLayout.JAVA_BYTE); + return new String(array, StandardCharsets.UTF_8); } } }