From 3cc4fc85d83122eba8944036691d00e195b6aa57 Mon Sep 17 00:00:00 2001 From: Peter Levart Date: Wed, 31 Jan 2024 18:07:56 +0100 Subject: [PATCH] update1: restructuring for better compilation (#661) --- calculate_average_plevart.sh | 1 + prepare_plevart.sh | 2 +- .../onebrc/CalculateAverage_plevart.java | 167 +++++++++--------- 3 files changed, 84 insertions(+), 86 deletions(-) diff --git a/calculate_average_plevart.sh b/calculate_average_plevart.sh index be195ac..32cee48 100755 --- a/calculate_average_plevart.sh +++ b/calculate_average_plevart.sh @@ -17,6 +17,7 @@ JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector" JAVA_OPTS="$JAVA_OPTS -XX:-TieredCompilation" +JAVA_OPTS="$JAVA_OPTS -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields" JAVA_OPTS="$JAVA_OPTS -XX:InlineSmallCode=15000 -XX:FreqInlineSize=400 -XX:MaxInlineSize=400" #JAVA_OPTS="$JAVA_OPTS -XX:+PrintCompilation -XX:+UnlockDiagnosticVMOptions -XX:+PrintInlining" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_plevart $* diff --git a/prepare_plevart.sh b/prepare_plevart.sh index d2a3c6b..5259fbe 100755 --- a/prepare_plevart.sh +++ b/prepare_plevart.sh @@ -16,4 +16,4 @@ # source "$HOME/.sdkman/bin/sdkman-init.sh" -sdk use java 21.0.1-tem 1>&2 +sdk use java 21.0.2-tem 1>&2 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java b/src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java index fd42d45..80c9e89 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java @@ -29,6 +29,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.Comparator; +import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -43,9 +44,10 @@ public class CalculateAverage_plevart { private static final int INITIAL_TABLE_CAPACITY = 8192; public static void main(String[] args) throws IOException { - var arena = Arena.global(); + System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0"); try ( - var channel = (FileChannel) Files.newByteChannel(FILE, StandardOpenOption.READ)) { + var channel = (FileChannel) Files.newByteChannel(FILE, StandardOpenOption.READ); + var arena = Arena.ofShared()) { var segment = channel.map(FileChannel.MapMode.READ_ONLY, 0, Files.size(FILE), arena); int regions = Runtime.getRuntime().availableProcessors(); IntStream @@ -54,7 +56,6 @@ public class CalculateAverage_plevart { .mapToObj(r -> calculateRegion(segment, regions, r)) .reduce(StatsTable::reduce) .ifPresent(System.out::println); - segment.unload(); } } @@ -68,14 +69,12 @@ public class CalculateAverage_plevart { end = skipPastNl(segment, end); } - var stats = new StatsTable(segment, INITIAL_TABLE_CAPACITY); - calculateAdjustedRegion(segment, start, end, stats); - return stats; + return calculateAdjustedRegion(segment, start, end); } private static long skipPastNl(MemorySegment segment, long i) { int skipped = 0; - while (skipped++ < MAX_LINE_LEN && getByte(segment, i++) != '\n') { + while (skipped++ < MAX_LINE_LEN && segment.get(ValueLayout.JAVA_BYTE, i++) != '\n') { } if (skipped > MAX_LINE_LEN) { throw new IllegalArgumentException( @@ -84,27 +83,28 @@ public class CalculateAverage_plevart { return i; } - private static void calculateAdjustedRegion(MemorySegment segment, long start, long end, StatsTable stats) { + private static StatsTable calculateAdjustedRegion(MemorySegment segment, long start, long end) { + var stats = new StatsTable(segment, INITIAL_TABLE_CAPACITY); + var species = ByteVector.SPECIES_PREFERRED; - long speciesByteSize = species.vectorByteSize(); long cityStart = start, numberStart = 0; int cityLen = 0; for (long i = start, j = i; i < end; j = i) { long semiNlSet; - if (end - i >= speciesByteSize) { + if (end - i >= species.vectorByteSize()) { var vec = ByteVector.fromMemorySegment(species, segment, i, ByteOrder.nativeOrder()); semiNlSet = vec.compare(VectorOperators.EQ, (byte) ';') .or(vec.compare(VectorOperators.EQ, (byte) '\n')) .toLong(); - i += speciesByteSize; + i += species.vectorByteSize(); } else { // tail, smaller than speciesByteSize semiNlSet = 0; long mask = 1; while (i < end && mask != 0) { - int c = getByte(segment, i++); + int c = segment.get(ValueLayout.JAVA_BYTE, i++); if (c == '\n' || c == ';') { semiNlSet |= mask; } @@ -120,63 +120,17 @@ public class CalculateAverage_plevart { } else { // nl int numberLen = (int) (j - numberStart); - calculateEntry(segment, cityStart, cityLen, numberStart, numberLen, stats); + stats.calculateEntry(cityStart, cityLen, numberStart, numberLen); cityStart = ++j; numberStart = 0; } } } + + return stats; } - private static void calculateEntry(MemorySegment segment, long cityStart, int cityLen, long numberStart, int numberLen, StatsTable stats) { - int hash = StatsTable.hash(segment, cityStart, cityLen); - int number = parseNumber(segment, numberStart, numberLen); - stats.aggregate(cityStart, cityLen, hash, 1, number, number, number); - } - - private static int parseNumber(MemorySegment segment, long off, int len) { - int c0 = getByte(segment, off); - int d0; - int sign; - if (c0 == '-') { - off++; - len--; - d0 = getByte(segment, off) - '0'; - sign = -1; - } else { - d0 = c0 - '0'; - sign = 1; - } - return sign * switch (len) { - case 1 -> d0 * 10; // 9 - case 2 -> { - int d1 = getByte(segment, off + 1) - '0'; - yield d0 * 100 + d1 * 10; // 99 - } - case 3 -> { - int d2 = getByte(segment, off + 2) - '0'; - yield d0 * 10 + d2; // 9.9 - } - case 4 -> { - int d1 = getByte(segment, off + 1) - '0'; - int d3 = getByte(segment, off + 3) - '0'; - yield d0 * 100 + d1 * 10 + d3; // 99.9 - } - default -> { - throw new IllegalArgumentException("Invalid number: " + getString(segment, off, len)); - } - }; - } - - private static int getByte(MemorySegment segment, long off) { - return segment.get(ValueLayout.JAVA_BYTE, off); - } - - private static String getString(MemorySegment segment, long off, int len) { - return new String(segment.asSlice(off, len).toArray(ValueLayout.JAVA_BYTE), StandardCharsets.UTF_8); - } - - final static class StatsTable implements Cloneable { + final static class StatsTable { private static final int LOAD_FACTOR = 16; // offsets of fields private static final int _lenHash = 0, @@ -190,7 +144,7 @@ public class CalculateAverage_plevart { private long[] table; StatsTable(MemorySegment segment, int capacity) { - this.segment = segment; + this.segment = Objects.requireNonNull(segment); int pow2cap = Integer.highestOneBit(capacity); if (pow2cap < capacity) { pow2cap <<= 1; @@ -199,6 +153,13 @@ public class CalculateAverage_plevart { this.table = new long[idx(pow2cap)]; } + private StatsTable(StatsTable st) { + this.segment = st.segment; + this.pow2cap = st.pow2cap; + this.loadedSize = st.loadedSize; + this.table = st.table; + } + private static int idx(int i) { return i << 3; } @@ -237,7 +198,49 @@ public class CalculateAverage_plevart { } } - static int hash(MemorySegment segment, long off, int len) { + void calculateEntry(long cityStart, int cityLen, long numberStart, int numberLen) { + int hash = hash(cityStart, cityLen); + int number = parseNumber(numberStart, numberLen); + aggregate(cityStart, cityLen, hash, 1, number, number, number); + } + + int parseNumber(long off, int len) { + int c0 = segment.get(ValueLayout.JAVA_BYTE, off); + int d0; + int sign; + if (c0 == '-') { + off++; + len--; + d0 = segment.get(ValueLayout.JAVA_BYTE, off) - '0'; + sign = -1; + } else { + d0 = c0 - '0'; + sign = 1; + } + return sign * switch (len) { + case 1 -> d0 * 10; // 9 + case 2 -> { + int d1 = segment.get(ValueLayout.JAVA_BYTE, off + 1) - '0'; + yield d0 * 100 + d1 * 10; // 99 + } + case 3 -> { + int d2 = segment.get(ValueLayout.JAVA_BYTE, off + 2) - '0'; + yield d0 * 10 + d2; // 9.9 + } + case 4 -> { + int d1 = segment.get(ValueLayout.JAVA_BYTE, off + 1) - '0'; + int d3 = segment.get(ValueLayout.JAVA_BYTE, off + 3) - '0'; + yield d0 * 100 + d1 * 10 + d3; // 99.9 + } + default -> + throw new IllegalArgumentException( + "Invalid number: " + + new String(segment.asSlice(off, len).toArray(ValueLayout.JAVA_BYTE), StandardCharsets.UTF_8) + ); + }; + } + + int hash(long off, int len) { if (len > Integer.BYTES) { int head = segment.get(ValueLayout.JAVA_INT_UNALIGNED, off); int tail = segment.get(ValueLayout.JAVA_INT_UNALIGNED, off + len - Integer.BYTES); @@ -251,7 +254,11 @@ public class CalculateAverage_plevart { } } - static boolean equals(MemorySegment segment, long off1, long off2, int len) { + private static boolean bothLessThan(long a, long b, long threshold) { + return (a < threshold) && (b < threshold); + } + + boolean equals(long off1, long off2, int len) { while (len >= Long.BYTES) { if (segment.get(ValueLayout.JAVA_LONG_UNALIGNED, off1) != segment.get(ValueLayout.JAVA_LONG_UNALIGNED, off2)) { return false; @@ -261,16 +268,16 @@ public class CalculateAverage_plevart { len -= Long.BYTES; } // still enough memory to compare two longs, but masked? - if (Math.max(off1, off2) + Long.BYTES <= segment.byteSize()) { + if (bothLessThan(off1, off2, segment.byteSize() - Long.BYTES + 1)) { long mask = LEN_LONG_MASK[len]; return (segment.get(ValueLayout.JAVA_LONG_UNALIGNED, off1) & mask) == (segment.get(ValueLayout.JAVA_LONG_UNALIGNED, off2) & mask); } else { - return equalsAtBorder(segment, off1, off2, len); + return equalsAtBorder(off1, off2, len); } } - private static boolean equalsAtBorder(MemorySegment segment, long off1, long off2, int len) { + private boolean equalsAtBorder(long off1, long off2, int len) { if (len > Integer.BYTES) { if (segment.get(ValueLayout.JAVA_INT_UNALIGNED, off1) != segment.get(ValueLayout.JAVA_INT_UNALIGNED, off2)) { return false; @@ -290,7 +297,7 @@ public class CalculateAverage_plevart { // key long off, int len, int hash, // value - long count, long sum, long min, long max) { + long count, long sum, int min, int max) { long lenHash = lenHash(len, hash); int mask = pow2cap - 1; for (int i = hash & mask, probe = 0; probe < pow2cap; i = (i + 1) & mask, probe++) { @@ -309,11 +316,11 @@ public class CalculateAverage_plevart { } return; } - if (lenHash_i == lenHash && equals(segment, table[idx + _off], off, len)) { + if (lenHash_i == lenHash && equals(off, table[idx + _off], len)) { table[idx + _count] += count; table[idx + _sum] += sum; - table[idx + _min] = Math.min(min, table[idx + _min]); - table[idx + _max] = Math.max(max, table[idx + _max]); + table[idx + _min] = Math.min(min, (int) table[idx + _min]); + table[idx + _max] = Math.max(max, (int) table[idx + _max]); return; } } @@ -325,7 +332,7 @@ public class CalculateAverage_plevart { throw new OutOfMemoryError("StatsTable capacity exceeded"); } else { - var oldStats = clone(); + var oldStats = new StatsTable(this); pow2cap <<= 1; table = new long[idx(pow2cap)]; loadedSize = 0; @@ -333,16 +340,6 @@ public class CalculateAverage_plevart { } } - @Override - protected StatsTable clone() { - try { - return (StatsTable) super.clone(); - } - catch (CloneNotSupportedException e) { - throw new InternalError(e); - } - } - StatsTable reduce(StatsTable other) { other .idxStream() @@ -353,8 +350,8 @@ public class CalculateAverage_plevart { hash(other.table[idx + _lenHash]), other.table[idx + _count], other.table[idx + _sum], - other.table[idx + _min], - other.table[idx + _max])); + (int) other.table[idx + _min], + (int) other.table[idx + _max])); return this; }