From be5b3318b150243c0ea24def82c2b8d6f3f1e06d Mon Sep 17 00:00:00 2001 From: giovannicuccu Date: Sun, 28 Jan 2024 23:24:47 +0100 Subject: [PATCH] Solution without unsafe using vector API (#602) * Solution without unsafe * Solution without unsafe * Solution without unsafe, remove the usage of bytebuffer, passes the create_measurements3 test * bug fix for 10k test, update also the CreateMeasurements3.java to use '\n' as newline instead of the os value (if it runs on windows it uses crlf and "breaks" the file format ) --------- Co-authored-by: Giovanni Cuccu --- calculate_average_giovannicuccu.sh | 2 +- github_users.txt | 1 + prepare_giovannicuccu.sh | 0 .../CalculateAverage_giovannicuccu.java | 413 ++++++++++-------- .../morling/onebrc/CreateMeasurements3.java | 2 +- 5 files changed, 232 insertions(+), 186 deletions(-) mode change 100755 => 100644 calculate_average_giovannicuccu.sh mode change 100755 => 100644 prepare_giovannicuccu.sh diff --git a/calculate_average_giovannicuccu.sh b/calculate_average_giovannicuccu.sh old mode 100755 new mode 100644 index 314b5d8..2188385 --- a/calculate_average_giovannicuccu.sh +++ b/calculate_average_giovannicuccu.sh @@ -15,5 +15,5 @@ # limitations under the License. # -JAVA_OPTS="" +JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector -XX:-TieredCompilation" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_giovannicuccu diff --git a/github_users.txt b/github_users.txt index 497909c..eb3ac2c 100644 --- a/github_users.txt +++ b/github_users.txt @@ -1,3 +1,4 @@ +giovannicuccu;Giovanni Cuccu Ujjwalbharti;Ujjwal Bharti abfrmblr;Abhilash ags313;ags diff --git a/prepare_giovannicuccu.sh b/prepare_giovannicuccu.sh old mode 100755 new mode 100644 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java b/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java index 7b549dc..7123c2c 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java @@ -15,10 +15,19 @@ */ package dev.morling.onebrc; +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorSpecies; + import static java.util.stream.Collectors.*; import java.io.IOException; import java.io.RandomAccessFile; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; @@ -31,34 +40,42 @@ import java.util.*; import java.util.concurrent.*; /* - Solution without unsafe that borrows the ideas of splullara, thomasvue, royvanrijn + Solution without unsafe that borrows the ideas of splullara, thomasvue, royvanrijn and merykitty */ public class CalculateAverage_giovannicuccu { private static final String FILE = "./measurements.txt"; - public static record PartitionBoundary(long start, long end) { + private static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_256; + private static final int BYTE_SPECIES_LANES = BYTE_SPECIES.length(); + private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); + public static final VectorSpecies INT_SPECIES = IntVector.SPECIES_256; + public static final int INT_SPECIES_LANES = INT_SPECIES.length(); + + public static final int KEY_SIZE = 128; + + public static record PartitionBoundary(Path path, long start, long end) { } public static interface PartitionCalculator { - PartitionBoundary[] computePartitionsBoundaries(Path path); + List computePartitionsBoundaries(Path path); } public static class ProcessorPartitionCalculator implements PartitionCalculator { - public PartitionBoundary[] computePartitionsBoundaries(Path path) { + public List computePartitionsBoundaries(Path path) { try { int numberOfSegments = Runtime.getRuntime().availableProcessors(); long fileSize = path.toFile().length(); long segmentSize = fileSize / numberOfSegments; - PartitionBoundary[] segmentBoundaries = new PartitionBoundary[numberOfSegments]; + List segmentBoundaries = new ArrayList<>(numberOfSegments); try (RandomAccessFile randomAccessFile = new RandomAccessFile(path.toFile(), "r")) { long segStart = 0; long segEnd = segmentSize; for (int i = 0; i < numberOfSegments; i++) { segEnd = findEndSegment(randomAccessFile, segEnd, fileSize); - segmentBoundaries[i] = new PartitionBoundary(segStart, segEnd); + segmentBoundaries.add(new PartitionBoundary(path, segStart, segEnd)); segStart = segEnd; segEnd = Math.min(segEnd + segmentSize, fileSize); } @@ -81,51 +98,27 @@ public class CalculateAverage_giovannicuccu { } } - public static class MeasurementAggregator { - private final int hash; + private static class MeasurementAggregatorVectorized { + private int min; private int max; private double sum; private long count; - private final byte[] station; + private final int len; + private final int hash; + private final int offset; - private final String name; + private byte[] data; - private final long[] data; - private final int dataOffset; - - public MeasurementAggregator(byte[] station, int offset, int hash, int initialValue, long[] data, int dataOffset) { + public MeasurementAggregatorVectorized(byte[] data, int offset, int len, int hash, int initialValue) { min = initialValue; max = initialValue; sum = initialValue; count = 1; - this.station = station; - this.offset = offset; + this.len = len; this.hash = hash; + this.offset = offset; this.data = data; - this.dataOffset = dataOffset; - this.name = new String(station, 0, offset, StandardCharsets.UTF_8); - } - - public MeasurementAggregator(byte[] station, int offset, int hash, int initialValue) { - min = initialValue; - max = initialValue; - sum = initialValue; - count = 1; - this.station = station; - this.offset = offset; - this.hash = hash; - this.data = new long[0]; - this.dataOffset = 0; - this.name = new String(station, 0, offset, StandardCharsets.UTF_8); - } - - public boolean hasSameStation(byte[] stationIn, int offsetIn) { - return Arrays.equals(stationIn, 0, offsetIn, station, 0, offset); - } - - public boolean hasSameStation(long[] dataIn, int offsetIn) { - return Arrays.equals(dataIn, 0, offsetIn, data, 0, dataOffset); } public void add(int value) { @@ -139,8 +132,7 @@ public class CalculateAverage_giovannicuccu { count++; } - public void merge(MeasurementAggregator other) { - // System.out.println("min=" +min + " other min=" +other.min); + public void merge(MeasurementAggregatorVectorized other) { min = Math.min(min, other.min); max = Math.max(max, other.max); sum += other.sum; @@ -149,7 +141,7 @@ public class CalculateAverage_giovannicuccu { @Override public String toString() { - return round((double) min / 10) + "/" + round((sum / (double) count) / 10) + "/" + round((double) max / 10); + return round(min / 10.) + "/" + round(sum / (double) (10 * count)) + "/" + round(max / 10.); } private double round(double value) { @@ -164,116 +156,141 @@ public class CalculateAverage_giovannicuccu { return hash; } - public String getName() { - return name; + public int getLen() { + return len; } - public byte[] getStation() { - return station; + public boolean dataEquals(byte[] data, int offset) { + return Arrays.equals(this.data, this.offset, this.offset + len, data, offset, offset + len); + + } + + public String getName() { + return new String(data, offset, len, StandardCharsets.UTF_8); } public int getOffset() { return offset; } - public long[] getData() { + public byte[] getData() { return data; } - } - public static class MeasurementList { - + private static class MeasurementListVectorized { private static final int SIZE = 1024 * 64; - private final MeasurementAggregator[] measurements = new MeasurementAggregator[SIZE]; + private final MeasurementAggregatorVectorized[] measurements = new MeasurementAggregatorVectorized[SIZE]; + private final byte[] keyData = new byte[SIZE * KEY_SIZE]; - public void add(byte[] station, int offset, int hash, int value) { + private final MemorySegment dataSegment = MemorySegment.ofArray(keyData); + + public void addWithByteVector(ByteVector chunk1, int len, int hash, int value, MemorySegment memorySegment, long offset) { int index = hash & (SIZE - 1); - if (measurements[index] == null) { - measurements[index] = new MeasurementAggregator(station.clone(), offset, hash, value); - } - else { - if (measurements[index].hasSameStation(station, offset)) { - measurements[index].add(value); - } - else { - while (measurements[index] != null && !measurements[index].hasSameStation(station, offset)) { - index = (index + 1) & (SIZE - 1); - } - if (measurements[index] == null) { - measurements[index] = new MeasurementAggregator(station.clone(), offset, hash, value); - } - else { + int i = 0; + while (measurements[index] != null) { + if (measurements[index].getLen() == len && measurements[index].getHash() == hash) { + var nodeKey = ByteVector.fromArray(BYTE_SPECIES, keyData, index * KEY_SIZE); + long eqMask = chunk1.compare(VectorOperators.EQ, nodeKey).toLong(); + long validMask = -1L >>> (64 - len); + if ((eqMask & validMask) == validMask) { measurements[index].add(value); + return; } } + index = (index + 1) & (SIZE - 1); } + MemorySegment.copy(memorySegment, offset, dataSegment, (long) index * KEY_SIZE, len); + measurements[index] = new MeasurementAggregatorVectorized(keyData, index * KEY_SIZE, len, hash, value); } - public void merge(MeasurementAggregator measurementAggregator) { - int index = (measurementAggregator.getHash() & (SIZE - 1)); - if (measurements[index] == null) { - measurements[index] = measurementAggregator; - } - else { - while (measurements[index] != null && !measurements[index].hasSameStation(measurementAggregator.getStation(), measurementAggregator.getOffset())) { - index = (index + 1) & (SIZE - 1); - } - if (measurements[index] == null) { - measurements[index] = measurementAggregator; - } - else { - measurements[index].merge(measurementAggregator); - } + public void add(int len, int hash, int value, MemorySegment memorySegment, long offset) { + int index = hash & (SIZE - 1); + while (measurements[index] != null) { + if (measurements[index].getLen() == len && measurements[index].getHash() == hash) { + int i = 0; + while (i < len && keyData[index * KEY_SIZE + i] == memorySegment.get(ValueLayout.JAVA_BYTE, offset + i)) { + i++; + } + if (i == len) { + measurements[index].add(value); + return; + } + } + index = (index + 1) & (SIZE - 1); } + MemorySegment.copy(memorySegment, offset, dataSegment, (long) index * KEY_SIZE, len); + measurements[index] = new MeasurementAggregatorVectorized(keyData, index * KEY_SIZE, len, hash, value); } - public MeasurementAggregator[] getMeasurements() { + public void merge(MeasurementAggregatorVectorized measurementAggregator) { + int index = measurementAggregator.getHash() & (SIZE - 1); + while (measurements[index] != null) { + if (measurements[index].getLen() == measurementAggregator.getLen() && measurements[index].getHash() == measurementAggregator.getHash()) { + if (measurementAggregator.dataEquals(measurements[index].getData(), measurements[index].getOffset())) { + measurements[index].merge(measurementAggregator); + return; + } + } + index = (index + 1) & (SIZE - 1); + } + measurements[index] = measurementAggregator; + } + + public MeasurementAggregatorVectorized[] getMeasurements() { return measurements; } + } - public static class MMapReader { + private static class MMapReaderMemorySegment { + private final Path path; - private final PartitionBoundary[] boundaries; - + private final List boundaries; private final boolean serial; + private static final byte SEPARATOR = ';'; + ByteVector separators = ByteVector.broadcast(BYTE_SPECIES, SEPARATOR); + private static final ValueLayout.OfLong JAVA_LONG_LT = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); - public MMapReader(Path path, PartitionCalculator partitionCalculator, boolean serial) { + public MMapReaderMemorySegment(Path path, PartitionCalculator partitionCalculator, boolean serial) { this.path = path; this.serial = serial; boundaries = partitionCalculator.computePartitionsBoundaries(path); } - public TreeMap elaborate() { - try (ExecutorService executor = Executors.newFixedThreadPool(boundaries.length)) { - List> futures = new ArrayList<>(); + public TreeMap elaborate() throws IOException { + try (ExecutorService executor = Executors.newFixedThreadPool(boundaries.size()); + FileChannel fileChannel = (FileChannel) Files.newByteChannel((path), StandardOpenOption.READ); + var arena = Arena.ofShared()) { + + List> futures = new ArrayList<>(); for (PartitionBoundary boundary : boundaries) { if (serial) { - FutureTask future = new FutureTask<>(() -> computeListForPartition(boundary.start(), boundary.end())); + FutureTask future = new FutureTask<>(() -> computeListForPartition( + fileChannel, boundary)); future.run(); - // System.out.println("done with partition " + boundary); futures.add(future); } else { - Future future = executor.submit(() -> computeListForPartition(boundary.start(), boundary.end())); + Future future = executor.submit(() -> computeListForPartition( + fileChannel, boundary)); futures.add(future); } } - TreeMap ris = reduce(futures); + TreeMap ris = reduce(futures); return ris; } } - private TreeMap reduce(List> futures) { + private TreeMap reduce(List> futures) { try { - TreeMap risMap = new TreeMap<>(); - MeasurementList ris = new MeasurementList(); - for (Future future : futures) { - MeasurementList results = future.get(); + TreeMap risMap = new TreeMap<>(); + MeasurementListVectorized ris = new MeasurementListVectorized(); + for (Future future : futures) { + MeasurementListVectorized results = future.get(); merge(ris, results); } - for (MeasurementAggregator m : ris.getMeasurements()) { + for (MeasurementAggregatorVectorized m : ris.getMeasurements()) { if (m != null) { risMap.put(m.getName(), m); } @@ -286,101 +303,134 @@ public class CalculateAverage_giovannicuccu { } } - private void merge(MeasurementList result, MeasurementList partial) { - for (MeasurementAggregator m : partial.getMeasurements()) { + private void merge(MeasurementListVectorized result, MeasurementListVectorized partial) { + for (MeasurementAggregatorVectorized m : partial.getMeasurements()) { if (m != null) { result.merge(m); } } } - private MeasurementList computeListForPartition(long start, long end) { - MeasurementList list = new MeasurementList(); - try { - try (FileChannel fileChannel = (FileChannel) Files.newByteChannel((path), StandardOpenOption.READ)) { - MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, start, end - start); - mappedByteBuffer.order(BYTE_ORDER.LITTLE_ENDIAN); - int limit = mappedByteBuffer.limit(); - int startLine; - byte[] stationb = new byte[100]; - while ((startLine = mappedByteBuffer.position()) < limit - 110) { - int currentPosition = startLine; - byte b = 0; - int i = 0; - int hash = 0; - - while ((b = mappedByteBuffer.get(currentPosition++)) != ';') { - stationb[i++] = b; - hash = 31 * hash + b; + private MeasurementListVectorized computeListForPartition(FileChannel fileChannel, PartitionBoundary boundary) { + try (var arena = Arena.ofConfined()) { + var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, boundary.start(), boundary.end() - boundary.start(), arena); + MeasurementListVectorized list = new MeasurementListVectorized(); + long size = memorySegment.byteSize(); + long offset = 0; + long safe = size - KEY_SIZE; + // ByteBuffer byteBuffer = memorySegment.asByteBuffer(); + // byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + ByteVector chunk1 = ByteVector.zero(BYTE_SPECIES); + ByteVector chunk2 = ByteVector.zero(BYTE_SPECIES); + while (offset < safe) { + int len = 0; + chunk1 = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment, offset, NATIVE_ORDER); + int equals = chunk1.compare(VectorOperators.EQ, separators).firstTrue(); + len += equals; + if (equals == BYTE_SPECIES_LANES) { + while (memorySegment.get(ValueLayout.JAVA_BYTE, offset + len) != ';') { + len++; } - if (hash < 0) { - hash = -hash; - } - - long numberWord = mappedByteBuffer.getLong(currentPosition); - int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); - int value = convertIntoNumber(decimalSepPos, numberWord); - mappedByteBuffer.position(currentPosition + (decimalSepPos >>> 3) + 3); - - list.add(stationb, i, hash, value); - } - while ((startLine = mappedByteBuffer.position()) < limit) { - int currentPosition = startLine; - byte b = 0; - int i = 0; - int hash = 0; - while ((b = mappedByteBuffer.get(currentPosition++)) != ';') { - stationb[i++] = b; - hash = 31 * hash + b; - } - if (hash < 0) { - hash = -hash; - } - int value = 0; - if (currentPosition <= limit - 8) { - long numberWord = mappedByteBuffer.getLong(currentPosition); - int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); - value = convertIntoNumber(decimalSepPos, numberWord); - mappedByteBuffer.position(currentPosition + (decimalSepPos >>> 3) + 3); - } - else { - int sign = 1; - b = mappedByteBuffer.get(currentPosition++); - if (b == '-') { - sign = -1; - } - else { - value = b - '0'; - } - while ((b = mappedByteBuffer.get(currentPosition++)) != '.') { - value = value * 10 + (b - '0'); - } - b = mappedByteBuffer.get(currentPosition); - value = value * 10 + (b - '0'); - if (sign == -1) { - value = -value; - } - mappedByteBuffer.position(currentPosition + 2); - } + int hash = hash(memorySegment, offset, len); + long prevOffset = offset; + offset += len + 1; - list.add(stationb, i, hash, value); + long numberWord = memorySegment.get(JAVA_LONG_LT, offset); + int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); + int value = convertIntoNumber(decimalSepPos, numberWord); + offset += (decimalSepPos >>> 3) + 3; + // System.out.println("Value=" + value); + if (len < BYTE_SPECIES_LANES) { + list.addWithByteVector(chunk1, len, hash, value, memorySegment, prevOffset); + } + else { + list.add(len, hash, value, memorySegment, prevOffset); } } + + while (offset < size) { + int len = 0; + int equals = BYTE_SPECIES_LANES; + if (offset + BYTE_SPECIES_LANES < size) { + chunk1 = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment, offset, NATIVE_ORDER); + equals = chunk1.compare(VectorOperators.EQ, separators).firstTrue(); + len += equals; + if (equals == BYTE_SPECIES_LANES) { + while (memorySegment.get(ValueLayout.JAVA_BYTE, offset + len) != ';') { + len++; + } + } + } + else { + byte[] bytes = new byte[BYTE_SPECIES_LANES]; + MemorySegment.copy(memorySegment, offset + len, MemorySegment.ofArray(bytes), 0, (size - offset - len)); + // byteBuffer.get(offset + len, bytes, 0, (int) (size - offset - len)); + chunk1 = ByteVector.fromArray(BYTE_SPECIES, bytes, 0); + equals = chunk1.compare(VectorOperators.EQ, separators).firstTrue(); + len += equals; + } + int hash = hash(memorySegment, offset, len); + long prevOffset = offset; + offset += len + 1; + + int value = 0; + if (offset < size - 8) { + long numberWord = memorySegment.get(JAVA_LONG_LT, offset); + int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); + value = convertIntoNumber(decimalSepPos, numberWord); + offset += (decimalSepPos >>> 3) + 3; + } + else { + long currentPosition = offset; + int sign = 1; + byte b = memorySegment.get(ValueLayout.JAVA_BYTE, currentPosition++); + if (b == '-') { + sign = -1; + } + else { + value = b - '0'; + } + while ((b = memorySegment.get(ValueLayout.JAVA_BYTE, currentPosition++)) != '.') { + value = value * 10 + (b - '0'); + } + b = memorySegment.get(ValueLayout.JAVA_BYTE, currentPosition); + value = value * 10 + (b - '0'); + if (sign == -1) { + value = -value; + } + offset = currentPosition + 2; + } + if (len < BYTE_SPECIES_LANES) { + list.addWithByteVector(chunk1, len, hash, value, memorySegment, prevOffset); + } + else { + list.add(len, hash, value, memorySegment, prevOffset); + } + } + return list; } catch (IOException e) { - System.out.println("Error"); - System.err.println(e); + throw new RuntimeException(e); } - return list; } - private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder(); + private static final int GOLDEN_RATIO = 0x9E3779B9; + private static final int HASH_LROTATE = 5; - private static long getLongLittleEndian(long value) { - value = Long.reverseBytes(value); - return value; + private static int hash(MemorySegment memorySegment, long start, int len) { + int x; + int y; + if (len >= Integer.BYTES) { + x = memorySegment.get(ValueLayout.JAVA_INT_UNALIGNED, start); + y = memorySegment.get(ValueLayout.JAVA_INT_UNALIGNED, start + len - Integer.BYTES); + } + else { + x = memorySegment.get(ValueLayout.JAVA_BYTE, start); + y = memorySegment.get(ValueLayout.JAVA_BYTE, start + len - Byte.BYTES); + } + return (Integer.rotateLeft(x * GOLDEN_RATIO, HASH_LROTATE) ^ y) * GOLDEN_RATIO; } private static int convertIntoNumber(int decimalSepPos, long numberWord) { @@ -405,16 +455,11 @@ public class CalculateAverage_giovannicuccu { return (int) value; } - private static long[] masks = new long[]{ 0x0000000000000000, 0xFF00000000000000L, 0xFFFF000000000000L, - 0xFFFFFF0000000000L, 0xFFFFFFFF00000000L, 0xFFFFFFFFFF000000L, 0xFFFFFFFFFF0000L, 0xFFFFFFFFFFFF00L }; - } public static void main(String[] args) throws IOException { - long start = System.currentTimeMillis(); - MMapReader reader = new MMapReader(Paths.get(FILE), new ProcessorPartitionCalculator(), false); - Map measurements = reader.elaborate(); - // System.out.println("ela=" + (System.currentTimeMillis() - start)); + MMapReaderMemorySegment reader = new MMapReaderMemorySegment(Paths.get(FILE), new ProcessorPartitionCalculator(), false); + Map measurements = reader.elaborate(); System.out.println(measurements); } diff --git a/src/main/java/dev/morling/onebrc/CreateMeasurements3.java b/src/main/java/dev/morling/onebrc/CreateMeasurements3.java index 804b83c..9bcc16d 100644 --- a/src/main/java/dev/morling/onebrc/CreateMeasurements3.java +++ b/src/main/java/dev/morling/onebrc/CreateMeasurements3.java @@ -55,7 +55,7 @@ public class CreateMeasurements3 { out.write(station.name); out.write(';'); out.write(Double.toString(Math.round(temp * 10.0) / 10.0)); - out.newLine(); + out.write('\n'); if (i % 50_000_000 == 0) { System.out.printf("Wrote %,d measurements in %,d ms%n", i, System.currentTimeMillis() - start); }