From f409fe0815c18e0d79bc161b1f8d3baeb2ad5771 Mon Sep 17 00:00:00 2001 From: Juan Parera <1420988+jparera@users.noreply.github.com> Date: Fri, 19 Jan 2024 22:06:48 +0100 Subject: [PATCH] Change data storage improving memory locality (#496) --- .../onebrc/CalculateAverage_jparera.java | 235 ++++++++++-------- 1 file changed, 133 insertions(+), 102 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java index 1325255..194dbcc 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java @@ -1,3 +1,5 @@ +//COMPILE_OPTIONS -source 21 --enable-preview --add-modules jdk.incubator.vector +//RUNTIME_OPTIONS --enable-preview --add-modules jdk.incubator.vector /* * Copyright 2023 The original authors * @@ -19,6 +21,8 @@ import java.io.IOException; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; import java.nio.ByteOrder; import java.nio.channels.FileChannel; import java.nio.channels.FileChannel.MapMode; @@ -26,7 +30,6 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.TreeMap; import java.util.function.Function; @@ -34,25 +37,41 @@ import java.util.stream.Collectors; import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.VectorSpecies; +import jdk.incubator.vector.VectorOperators; public class CalculateAverage_jparera { private static final String FILE = "./measurements.txt"; - private static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_PREFERRED; + private static final VarHandle BYTE_HANDLE = MethodHandles + .memorySegmentViewVarHandle(ValueLayout.JAVA_BYTE); - private static final int BYTE_SPECIES_SIZE = BYTE_SPECIES.vectorByteSize(); + private static final VarHandle INT_HANDLE = MethodHandles + .memorySegmentViewVarHandle(ValueLayout.JAVA_INT_UNALIGNED); + + private static final VarHandle LONG_LE_HANDLE = MethodHandles + .memorySegmentViewVarHandle(ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN)); + + private static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_PREFERRED; private static final int BYTE_SPECIES_LANES = BYTE_SPECIES.length(); - private static final ValueLayout.OfLong LONG_U_LE = ValueLayout.JAVA_LONG_UNALIGNED - .withOrder(ByteOrder.LITTLE_ENDIAN); + private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); - public static void main(String[] args) throws IOException { + private static final byte LF = '\n'; + + private static final byte SEPARATOR = ';'; + + private static final byte DECIMAL_SEPARATOR = '.'; + + private static final byte NEG = '-'; + + public static void main(String[] args) throws IOException, InterruptedException { try (var fc = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { try (var arena = Arena.ofShared()) { var fs = fc.map(MapMode.READ_ONLY, 0, fc.size(), arena); - var map = chunks(fs) - .parallelStream() + var cpus = Runtime.getRuntime().availableProcessors(); + var output = chunks(fs, cpus).stream() + .parallel() .map(Chunk::parse) .flatMap(List::stream) .collect(Collectors.toMap( @@ -60,20 +79,19 @@ public class CalculateAverage_jparera { Function.identity(), Entry::merge, TreeMap::new)); - System.out.println(map); + System.out.println(output); } } } - private static Collection chunks(MemorySegment ms) { - var cpus = Runtime.getRuntime().availableProcessors(); - long expectedChunkSize = Math.ceilDiv(ms.byteSize(), cpus); - var chunks = new ArrayList(); + private static List chunks(MemorySegment ms, int splits) { long fileSize = ms.byteSize(); + long expectedChunkSize = Math.ceilDiv(fileSize, splits); + var chunks = new ArrayList(); long offset = 0; while (offset < fileSize) { var end = Math.min(offset + expectedChunkSize, fileSize); - while (end < fileSize && ms.get(ValueLayout.JAVA_BYTE, end++) != '\n') { + while (end < fileSize && (byte) BYTE_HANDLE.get(ms, end++) != LF) { } long len = end - offset; chunks.add(new Chunk(ms.asSlice(offset, len))); @@ -83,25 +101,27 @@ public class CalculateAverage_jparera { } private static final class Chunk { - private static final byte SEPARATOR = ';'; - - private static final byte DECIMAL_SEPARATOR = '.'; - - private static final byte LF = '\n'; - - private static final byte MINUS = '-'; - private static final int KEY_LOG2_BYTES = 7; private static final int KEY_BYTES = 1 << KEY_LOG2_BYTES; - private static final int MAP_CAPACITY = 1 << 16; + private static final int ENTRIES_LOG2_CAPACITY = 16; - private static final int BUCKET_MASK = MAP_CAPACITY - 1; + private static final int ENTRIES_CAPACITY = 1 << ENTRIES_LOG2_CAPACITY; + + private static final int ENTRIES_MASK = ENTRIES_CAPACITY - 1; private final MemorySegment segment; - private final Entry[] entries = new Entry[MAP_CAPACITY]; + private final long size; + + private final Entry[] entries = new Entry[ENTRIES_CAPACITY]; + + private final byte[] keys = new byte[ENTRIES_CAPACITY * KEY_BYTES]; + + private final MemorySegment kms = MemorySegment.ofArray(this.keys); + + private static final int KEYS_MASK = (ENTRIES_CAPACITY * KEY_BYTES) - 1; private long offset; @@ -111,26 +131,23 @@ public class CalculateAverage_jparera { Chunk(MemorySegment segment) { this.segment = segment; + this.size = segment.byteSize(); } public List parse() { - long size = this.segment.byteSize(); long safe = size - KEY_BYTES; while (offset < safe) { - var e = vectorizedEntry(); - int value = vectorizedValue(); - e.add(value); + vectorizedEntry().add(vectorizedValue()); } next(); while (hasCurrent()) { - var e = entry(); - int value = value(); - e.add(value); + entry().add(value()); } var output = new ArrayList(entries.length); - for (int i = 0; i < entries.length; i++) { + for (int i = 0, o = 0; i < entries.length; i++, o += KEY_BYTES) { var e = entries[i]; if (e != null) { + e.setkey(keys, o); output.add(e); } } @@ -138,29 +155,48 @@ public class CalculateAverage_jparera { } private Entry vectorizedEntry() { - var start = this.offset; - var first = ByteVector.fromMemorySegment(BYTE_SPECIES, this.segment, start, ByteOrder.nativeOrder()); - int equals = first.eq(SEPARATOR).firstTrue(); - int len = equals; - for (int i = BYTE_SPECIES_SIZE; equals == BYTE_SPECIES_LANES; i += BYTE_SPECIES_SIZE) { - var next = ByteVector.fromMemorySegment(BYTE_SPECIES, this.segment, start + i, ByteOrder.nativeOrder()); - equals = next.eq(SEPARATOR).firstTrue(); + var separators = ByteVector.broadcast(BYTE_SPECIES, SEPARATOR); + int len = 0; + for (int i = 0;; i += BYTE_SPECIES_LANES) { + var block = ByteVector.fromMemorySegment(BYTE_SPECIES, this.segment, offset + i, NATIVE_ORDER); + int equals = block.compare(VectorOperators.EQ, separators).firstTrue(); len += equals; + if (equals != BYTE_SPECIES_LANES) { + break; + } } + var start = this.offset; this.offset = start + len + 1; - int index = hash(this.segment, start, len); + int hash = hash(segment, start, len); + int index = (hash - (hash >>> -ENTRIES_LOG2_CAPACITY)) & ENTRIES_MASK; + int keyOffset = index << KEY_LOG2_BYTES; int count = 0; - while (count < BUCKET_MASK) { - index = index & BUCKET_MASK; + while (count < ENTRIES_MASK) { + index = index & ENTRIES_MASK; + keyOffset = keyOffset & KEYS_MASK; var e = this.entries[index]; if (e == null) { - return this.entries[index] = new Entry(len, this.segment.asSlice(start, KEY_BYTES)); + MemorySegment.copy(this.segment, start, kms, keyOffset, len); + return this.entries[index] = new Entry(len, hash); } - else if (e.keyLength() == len && vectorizedEquals(e, first, start, len)) { - return e; + else if (e.hash == hash && e.keyLength == len) { + int total = 0; + for (int i = 0; i < KEY_BYTES; i += BYTE_SPECIES_LANES) { + var ekey = ByteVector.fromArray(BYTE_SPECIES, keys, keyOffset + i); + var okey = ByteVector.fromMemorySegment(BYTE_SPECIES, this.segment, start + i, NATIVE_ORDER); + int equals = ekey.compare(VectorOperators.NE, okey).firstTrue(); + total += equals; + if (equals != BYTE_SPECIES_LANES) { + break; + } + } + if (total >= len) { + return e; + } } - index++; count++; + index++; + keyOffset += KEY_BYTES; } throw new IllegalStateException("Map is full!"); } @@ -173,19 +209,33 @@ public class CalculateAverage_jparera { next(); } expect(SEPARATOR); - int index = hash(segment, start, len); + int hash = hash(segment, start, len); + int index = (hash - (hash >>> -ENTRIES_LOG2_CAPACITY)) & ENTRIES_MASK; + int keyOffset = index << KEY_LOG2_BYTES; int count = 0; - while (count < BUCKET_MASK) { - index = index & BUCKET_MASK; + while (count < ENTRIES_MASK) { + index = index & ENTRIES_MASK; + keyOffset = keyOffset & KEYS_MASK; var e = this.entries[index]; if (e == null) { - return this.entries[index] = new Entry(len, this.segment.asSlice(start, len)); + MemorySegment.copy(this.segment, start, kms, keyOffset, len); + return this.entries[index] = new Entry(len, hash); } - else if (e.keyLength() == len && equals(e, start, len)) { - return e; + else if (e.hash == hash && e.keyLength == len) { + int total = 0; + for (int i = 0; i < len; i++) { + if (((byte) BYTE_HANDLE.get(this.segment, start + i)) != this.keys[keyOffset + i]) { + break; + } + total++; + } + if (total >= len) { + return e; + } } - index++; count++; + index++; + keyOffset += KEY_BYTES; } throw new IllegalStateException("Map is full!"); } @@ -193,9 +243,9 @@ public class CalculateAverage_jparera { private static final long MULTIPLY_ADD_DIGITS = 100 * (1L << 24) + 10 * (1L << 16) + 1; private int vectorizedValue() { - long dw = this.segment.get(LONG_U_LE, this.offset); - boolean negative = ((dw & 0xFF) ^ MINUS) == 0; + long dw = (long) LONG_LE_HANDLE.get(this.segment, this.offset); int zeros = Long.numberOfTrailingZeros(~dw & 0x10101000L); + boolean negative = ((dw & 0xFF) ^ NEG) == 0; dw = ((negative ? (dw & ~0xFF) : dw) << (28 - zeros)) & 0x0F000F0F00L; int value = (int) (((dw * MULTIPLY_ADD_DIGITS) >>> 32) & 0x3FF); this.offset += (zeros >>> 3) + 3; @@ -205,7 +255,7 @@ public class CalculateAverage_jparera { private int value() { int value = 0; var negative = false; - if (consume(MINUS)) { + if (consume(NEG)) { negative = true; } while (hasCurrent()) { @@ -224,41 +274,18 @@ public class CalculateAverage_jparera { return negative ? -value : value; } - private boolean vectorizedEquals(Entry entry, ByteVector okey, long offset, int len) { - var ekey = ByteVector.fromMemorySegment(BYTE_SPECIES, entry.segment(), 0, ByteOrder.nativeOrder()); - int equals = ekey.eq(okey).not().firstTrue(); - if (equals != BYTE_SPECIES_LANES) { - return equals >= len; - } - long eo = BYTE_SPECIES_SIZE; - int total = BYTE_SPECIES_LANES; - while (equals == BYTE_SPECIES_LANES & eo < KEY_BYTES) { - offset += BYTE_SPECIES_SIZE; - ekey = ByteVector.fromMemorySegment(BYTE_SPECIES, entry.segment(), eo, ByteOrder.nativeOrder()); - okey = ByteVector.fromMemorySegment(BYTE_SPECIES, segment, offset, ByteOrder.nativeOrder()); - equals = ekey.eq(okey).not().firstTrue(); - total += equals; - eo += BYTE_SPECIES_SIZE; - } - return total >= len; - } - - private boolean equals(Entry entry, long offset, int len) { - return MemorySegment.mismatch(this.segment, offset, offset + len, entry.segment(), 0, len) == -1; - } - private static final int GOLDEN_RATIO = 0x9E3779B9; private static final int HASH_LROTATE = 5; private static int hash(MemorySegment ms, long start, int len) { int x, y; if (len >= Integer.BYTES) { - x = ms.get(ValueLayout.JAVA_INT_UNALIGNED, start); - y = ms.get(ValueLayout.JAVA_INT_UNALIGNED, start + len - Integer.BYTES); + x = (int) INT_HANDLE.get(ms, start); + y = (int) INT_HANDLE.get(ms, start + len - Integer.BYTES); } else { - x = ms.get(ValueLayout.JAVA_BYTE, start); - y = ms.get(ValueLayout.JAVA_BYTE, start + len - Byte.BYTES); + x = (byte) BYTE_HANDLE.get(ms, start) & 0xFF; + y = (byte) BYTE_HANDLE.get(ms, start + len - Byte.BYTES) & 0xFF; } return (Integer.rotateLeft(x * GOLDEN_RATIO, HASH_LROTATE) ^ y) * GOLDEN_RATIO; } @@ -282,8 +309,8 @@ public class CalculateAverage_jparera { } private void next() { - if (offset < segment.byteSize()) { - this.current = segment.get(ValueLayout.JAVA_BYTE, offset++); + if (offset < size) { + this.current = (byte) BYTE_HANDLE.get(segment, offset++); } else { this.hasCurrent = false; @@ -292,9 +319,9 @@ public class CalculateAverage_jparera { } private static final class Entry { - private final int keyLength; + final int keyLength; - private final MemorySegment segment; + final int hash; private int min = Integer.MAX_VALUE; @@ -304,21 +331,19 @@ public class CalculateAverage_jparera { private int count; - Entry(int keyLength, MemorySegment segment) { + private String key; + + Entry(int keyLength, int hash) { this.keyLength = keyLength; - this.segment = segment; - } - - int keyLength() { - return keyLength; - } - - MemorySegment segment() { - return segment; + this.hash = hash; } public String key() { - return new String(segment.asSlice(0, keyLength).toArray(ValueLayout.JAVA_BYTE), StandardCharsets.UTF_8); + return key; + } + + void setkey(byte[] keys, int offset) { + this.key = new String(keys, offset, keyLength, StandardCharsets.UTF_8); } public void add(int value) { @@ -339,13 +364,19 @@ public class CalculateAverage_jparera { @Override public String toString() { var average = Math.round(((sum / 10.0) / count) * 10.0); - return decimal(min) + "/" + decimal(average) + "/" + decimal(max); + return decimal(min) + '/' + decimal(average) + '/' + decimal(max); } private static String decimal(long value) { - boolean negative = value < 0; + var builder = new StringBuilder(); + if (value < 0) { + builder.append((char) NEG); + } value = Math.abs(value); - return (negative ? "-" : "") + (value / 10) + "." + (value % 10); + builder.append(value / 10); + builder.append((char) DECIMAL_SEPARATOR); + builder.append(value % 10); + return builder.toString(); } } }