diff --git a/calculate_average_giovannicuccu.sh b/calculate_average_giovannicuccu.sh old mode 100755 new mode 100644 index 314b5d8..2188385 --- a/calculate_average_giovannicuccu.sh +++ b/calculate_average_giovannicuccu.sh @@ -15,5 +15,5 @@ # limitations under the License. # -JAVA_OPTS="" +JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector -XX:-TieredCompilation" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_giovannicuccu diff --git a/github_users.txt b/github_users.txt index 497909c..eb3ac2c 100644 --- a/github_users.txt +++ b/github_users.txt @@ -1,3 +1,4 @@ +giovannicuccu;Giovanni Cuccu Ujjwalbharti;Ujjwal Bharti abfrmblr;Abhilash ags313;ags diff --git a/prepare_giovannicuccu.sh b/prepare_giovannicuccu.sh old mode 100755 new mode 100644 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java b/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java index 7b549dc..7123c2c 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java @@ -15,10 +15,19 @@ */ package dev.morling.onebrc; +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorSpecies; + import static java.util.stream.Collectors.*; import java.io.IOException; import java.io.RandomAccessFile; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; @@ -31,34 +40,42 @@ import java.util.*; import java.util.concurrent.*; /* - Solution without unsafe that borrows the ideas of splullara, thomasvue, royvanrijn + Solution without unsafe that borrows the ideas of splullara, thomasvue, royvanrijn and merykitty */ public class CalculateAverage_giovannicuccu { private static final String FILE = "./measurements.txt"; - public static record PartitionBoundary(long start, long end) { + private static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_256; + private static final int BYTE_SPECIES_LANES = BYTE_SPECIES.length(); + private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder(); + public static final VectorSpecies INT_SPECIES = IntVector.SPECIES_256; + public static final int INT_SPECIES_LANES = INT_SPECIES.length(); + + public static final int KEY_SIZE = 128; + + public static record PartitionBoundary(Path path, long start, long end) { } public static interface PartitionCalculator { - PartitionBoundary[] computePartitionsBoundaries(Path path); + List computePartitionsBoundaries(Path path); } public static class ProcessorPartitionCalculator implements PartitionCalculator { - public PartitionBoundary[] computePartitionsBoundaries(Path path) { + public List computePartitionsBoundaries(Path path) { try { int numberOfSegments = Runtime.getRuntime().availableProcessors(); long fileSize = path.toFile().length(); long segmentSize = fileSize / numberOfSegments; - PartitionBoundary[] segmentBoundaries = new PartitionBoundary[numberOfSegments]; + List segmentBoundaries = new ArrayList<>(numberOfSegments); try (RandomAccessFile randomAccessFile = new RandomAccessFile(path.toFile(), "r")) { long segStart = 0; long segEnd = segmentSize; for (int i = 0; i < numberOfSegments; i++) { segEnd = findEndSegment(randomAccessFile, segEnd, fileSize); - segmentBoundaries[i] = new PartitionBoundary(segStart, segEnd); + segmentBoundaries.add(new PartitionBoundary(path, segStart, segEnd)); segStart = segEnd; segEnd = Math.min(segEnd + segmentSize, fileSize); } @@ -81,51 +98,27 @@ public class CalculateAverage_giovannicuccu { } } - public static class MeasurementAggregator { - private final int hash; + private static class MeasurementAggregatorVectorized { + private int min; private int max; private double sum; private long count; - private final byte[] station; + private final int len; + private final int hash; + private final int offset; - private final String name; + private byte[] data; - private final long[] data; - private final int dataOffset; - - public MeasurementAggregator(byte[] station, int offset, int hash, int initialValue, long[] data, int dataOffset) { + public MeasurementAggregatorVectorized(byte[] data, int offset, int len, int hash, int initialValue) { min = initialValue; max = initialValue; sum = initialValue; count = 1; - this.station = station; - this.offset = offset; + this.len = len; this.hash = hash; + this.offset = offset; this.data = data; - this.dataOffset = dataOffset; - this.name = new String(station, 0, offset, StandardCharsets.UTF_8); - } - - public MeasurementAggregator(byte[] station, int offset, int hash, int initialValue) { - min = initialValue; - max = initialValue; - sum = initialValue; - count = 1; - this.station = station; - this.offset = offset; - this.hash = hash; - this.data = new long[0]; - this.dataOffset = 0; - this.name = new String(station, 0, offset, StandardCharsets.UTF_8); - } - - public boolean hasSameStation(byte[] stationIn, int offsetIn) { - return Arrays.equals(stationIn, 0, offsetIn, station, 0, offset); - } - - public boolean hasSameStation(long[] dataIn, int offsetIn) { - return Arrays.equals(dataIn, 0, offsetIn, data, 0, dataOffset); } public void add(int value) { @@ -139,8 +132,7 @@ public class CalculateAverage_giovannicuccu { count++; } - public void merge(MeasurementAggregator other) { - // System.out.println("min=" +min + " other min=" +other.min); + public void merge(MeasurementAggregatorVectorized other) { min = Math.min(min, other.min); max = Math.max(max, other.max); sum += other.sum; @@ -149,7 +141,7 @@ public class CalculateAverage_giovannicuccu { @Override public String toString() { - return round((double) min / 10) + "/" + round((sum / (double) count) / 10) + "/" + round((double) max / 10); + return round(min / 10.) + "/" + round(sum / (double) (10 * count)) + "/" + round(max / 10.); } private double round(double value) { @@ -164,116 +156,141 @@ public class CalculateAverage_giovannicuccu { return hash; } - public String getName() { - return name; + public int getLen() { + return len; } - public byte[] getStation() { - return station; + public boolean dataEquals(byte[] data, int offset) { + return Arrays.equals(this.data, this.offset, this.offset + len, data, offset, offset + len); + + } + + public String getName() { + return new String(data, offset, len, StandardCharsets.UTF_8); } public int getOffset() { return offset; } - public long[] getData() { + public byte[] getData() { return data; } - } - public static class MeasurementList { - + private static class MeasurementListVectorized { private static final int SIZE = 1024 * 64; - private final MeasurementAggregator[] measurements = new MeasurementAggregator[SIZE]; + private final MeasurementAggregatorVectorized[] measurements = new MeasurementAggregatorVectorized[SIZE]; + private final byte[] keyData = new byte[SIZE * KEY_SIZE]; - public void add(byte[] station, int offset, int hash, int value) { + private final MemorySegment dataSegment = MemorySegment.ofArray(keyData); + + public void addWithByteVector(ByteVector chunk1, int len, int hash, int value, MemorySegment memorySegment, long offset) { int index = hash & (SIZE - 1); - if (measurements[index] == null) { - measurements[index] = new MeasurementAggregator(station.clone(), offset, hash, value); - } - else { - if (measurements[index].hasSameStation(station, offset)) { - measurements[index].add(value); - } - else { - while (measurements[index] != null && !measurements[index].hasSameStation(station, offset)) { - index = (index + 1) & (SIZE - 1); - } - if (measurements[index] == null) { - measurements[index] = new MeasurementAggregator(station.clone(), offset, hash, value); - } - else { + int i = 0; + while (measurements[index] != null) { + if (measurements[index].getLen() == len && measurements[index].getHash() == hash) { + var nodeKey = ByteVector.fromArray(BYTE_SPECIES, keyData, index * KEY_SIZE); + long eqMask = chunk1.compare(VectorOperators.EQ, nodeKey).toLong(); + long validMask = -1L >>> (64 - len); + if ((eqMask & validMask) == validMask) { measurements[index].add(value); + return; } } + index = (index + 1) & (SIZE - 1); } + MemorySegment.copy(memorySegment, offset, dataSegment, (long) index * KEY_SIZE, len); + measurements[index] = new MeasurementAggregatorVectorized(keyData, index * KEY_SIZE, len, hash, value); } - public void merge(MeasurementAggregator measurementAggregator) { - int index = (measurementAggregator.getHash() & (SIZE - 1)); - if (measurements[index] == null) { - measurements[index] = measurementAggregator; - } - else { - while (measurements[index] != null && !measurements[index].hasSameStation(measurementAggregator.getStation(), measurementAggregator.getOffset())) { - index = (index + 1) & (SIZE - 1); - } - if (measurements[index] == null) { - measurements[index] = measurementAggregator; - } - else { - measurements[index].merge(measurementAggregator); - } + public void add(int len, int hash, int value, MemorySegment memorySegment, long offset) { + int index = hash & (SIZE - 1); + while (measurements[index] != null) { + if (measurements[index].getLen() == len && measurements[index].getHash() == hash) { + int i = 0; + while (i < len && keyData[index * KEY_SIZE + i] == memorySegment.get(ValueLayout.JAVA_BYTE, offset + i)) { + i++; + } + if (i == len) { + measurements[index].add(value); + return; + } + } + index = (index + 1) & (SIZE - 1); } + MemorySegment.copy(memorySegment, offset, dataSegment, (long) index * KEY_SIZE, len); + measurements[index] = new MeasurementAggregatorVectorized(keyData, index * KEY_SIZE, len, hash, value); } - public MeasurementAggregator[] getMeasurements() { + public void merge(MeasurementAggregatorVectorized measurementAggregator) { + int index = measurementAggregator.getHash() & (SIZE - 1); + while (measurements[index] != null) { + if (measurements[index].getLen() == measurementAggregator.getLen() && measurements[index].getHash() == measurementAggregator.getHash()) { + if (measurementAggregator.dataEquals(measurements[index].getData(), measurements[index].getOffset())) { + measurements[index].merge(measurementAggregator); + return; + } + } + index = (index + 1) & (SIZE - 1); + } + measurements[index] = measurementAggregator; + } + + public MeasurementAggregatorVectorized[] getMeasurements() { return measurements; } + } - public static class MMapReader { + private static class MMapReaderMemorySegment { + private final Path path; - private final PartitionBoundary[] boundaries; - + private final List boundaries; private final boolean serial; + private static final byte SEPARATOR = ';'; + ByteVector separators = ByteVector.broadcast(BYTE_SPECIES, SEPARATOR); + private static final ValueLayout.OfLong JAVA_LONG_LT = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); - public MMapReader(Path path, PartitionCalculator partitionCalculator, boolean serial) { + public MMapReaderMemorySegment(Path path, PartitionCalculator partitionCalculator, boolean serial) { this.path = path; this.serial = serial; boundaries = partitionCalculator.computePartitionsBoundaries(path); } - public TreeMap elaborate() { - try (ExecutorService executor = Executors.newFixedThreadPool(boundaries.length)) { - List> futures = new ArrayList<>(); + public TreeMap elaborate() throws IOException { + try (ExecutorService executor = Executors.newFixedThreadPool(boundaries.size()); + FileChannel fileChannel = (FileChannel) Files.newByteChannel((path), StandardOpenOption.READ); + var arena = Arena.ofShared()) { + + List> futures = new ArrayList<>(); for (PartitionBoundary boundary : boundaries) { if (serial) { - FutureTask future = new FutureTask<>(() -> computeListForPartition(boundary.start(), boundary.end())); + FutureTask future = new FutureTask<>(() -> computeListForPartition( + fileChannel, boundary)); future.run(); - // System.out.println("done with partition " + boundary); futures.add(future); } else { - Future future = executor.submit(() -> computeListForPartition(boundary.start(), boundary.end())); + Future future = executor.submit(() -> computeListForPartition( + fileChannel, boundary)); futures.add(future); } } - TreeMap ris = reduce(futures); + TreeMap ris = reduce(futures); return ris; } } - private TreeMap reduce(List> futures) { + private TreeMap reduce(List> futures) { try { - TreeMap risMap = new TreeMap<>(); - MeasurementList ris = new MeasurementList(); - for (Future future : futures) { - MeasurementList results = future.get(); + TreeMap risMap = new TreeMap<>(); + MeasurementListVectorized ris = new MeasurementListVectorized(); + for (Future future : futures) { + MeasurementListVectorized results = future.get(); merge(ris, results); } - for (MeasurementAggregator m : ris.getMeasurements()) { + for (MeasurementAggregatorVectorized m : ris.getMeasurements()) { if (m != null) { risMap.put(m.getName(), m); } @@ -286,101 +303,134 @@ public class CalculateAverage_giovannicuccu { } } - private void merge(MeasurementList result, MeasurementList partial) { - for (MeasurementAggregator m : partial.getMeasurements()) { + private void merge(MeasurementListVectorized result, MeasurementListVectorized partial) { + for (MeasurementAggregatorVectorized m : partial.getMeasurements()) { if (m != null) { result.merge(m); } } } - private MeasurementList computeListForPartition(long start, long end) { - MeasurementList list = new MeasurementList(); - try { - try (FileChannel fileChannel = (FileChannel) Files.newByteChannel((path), StandardOpenOption.READ)) { - MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, start, end - start); - mappedByteBuffer.order(BYTE_ORDER.LITTLE_ENDIAN); - int limit = mappedByteBuffer.limit(); - int startLine; - byte[] stationb = new byte[100]; - while ((startLine = mappedByteBuffer.position()) < limit - 110) { - int currentPosition = startLine; - byte b = 0; - int i = 0; - int hash = 0; - - while ((b = mappedByteBuffer.get(currentPosition++)) != ';') { - stationb[i++] = b; - hash = 31 * hash + b; + private MeasurementListVectorized computeListForPartition(FileChannel fileChannel, PartitionBoundary boundary) { + try (var arena = Arena.ofConfined()) { + var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, boundary.start(), boundary.end() - boundary.start(), arena); + MeasurementListVectorized list = new MeasurementListVectorized(); + long size = memorySegment.byteSize(); + long offset = 0; + long safe = size - KEY_SIZE; + // ByteBuffer byteBuffer = memorySegment.asByteBuffer(); + // byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + ByteVector chunk1 = ByteVector.zero(BYTE_SPECIES); + ByteVector chunk2 = ByteVector.zero(BYTE_SPECIES); + while (offset < safe) { + int len = 0; + chunk1 = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment, offset, NATIVE_ORDER); + int equals = chunk1.compare(VectorOperators.EQ, separators).firstTrue(); + len += equals; + if (equals == BYTE_SPECIES_LANES) { + while (memorySegment.get(ValueLayout.JAVA_BYTE, offset + len) != ';') { + len++; } - if (hash < 0) { - hash = -hash; - } - - long numberWord = mappedByteBuffer.getLong(currentPosition); - int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); - int value = convertIntoNumber(decimalSepPos, numberWord); - mappedByteBuffer.position(currentPosition + (decimalSepPos >>> 3) + 3); - - list.add(stationb, i, hash, value); - } - while ((startLine = mappedByteBuffer.position()) < limit) { - int currentPosition = startLine; - byte b = 0; - int i = 0; - int hash = 0; - while ((b = mappedByteBuffer.get(currentPosition++)) != ';') { - stationb[i++] = b; - hash = 31 * hash + b; - } - if (hash < 0) { - hash = -hash; - } - int value = 0; - if (currentPosition <= limit - 8) { - long numberWord = mappedByteBuffer.getLong(currentPosition); - int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); - value = convertIntoNumber(decimalSepPos, numberWord); - mappedByteBuffer.position(currentPosition + (decimalSepPos >>> 3) + 3); - } - else { - int sign = 1; - b = mappedByteBuffer.get(currentPosition++); - if (b == '-') { - sign = -1; - } - else { - value = b - '0'; - } - while ((b = mappedByteBuffer.get(currentPosition++)) != '.') { - value = value * 10 + (b - '0'); - } - b = mappedByteBuffer.get(currentPosition); - value = value * 10 + (b - '0'); - if (sign == -1) { - value = -value; - } - mappedByteBuffer.position(currentPosition + 2); - } + int hash = hash(memorySegment, offset, len); + long prevOffset = offset; + offset += len + 1; - list.add(stationb, i, hash, value); + long numberWord = memorySegment.get(JAVA_LONG_LT, offset); + int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); + int value = convertIntoNumber(decimalSepPos, numberWord); + offset += (decimalSepPos >>> 3) + 3; + // System.out.println("Value=" + value); + if (len < BYTE_SPECIES_LANES) { + list.addWithByteVector(chunk1, len, hash, value, memorySegment, prevOffset); + } + else { + list.add(len, hash, value, memorySegment, prevOffset); } } + + while (offset < size) { + int len = 0; + int equals = BYTE_SPECIES_LANES; + if (offset + BYTE_SPECIES_LANES < size) { + chunk1 = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment, offset, NATIVE_ORDER); + equals = chunk1.compare(VectorOperators.EQ, separators).firstTrue(); + len += equals; + if (equals == BYTE_SPECIES_LANES) { + while (memorySegment.get(ValueLayout.JAVA_BYTE, offset + len) != ';') { + len++; + } + } + } + else { + byte[] bytes = new byte[BYTE_SPECIES_LANES]; + MemorySegment.copy(memorySegment, offset + len, MemorySegment.ofArray(bytes), 0, (size - offset - len)); + // byteBuffer.get(offset + len, bytes, 0, (int) (size - offset - len)); + chunk1 = ByteVector.fromArray(BYTE_SPECIES, bytes, 0); + equals = chunk1.compare(VectorOperators.EQ, separators).firstTrue(); + len += equals; + } + int hash = hash(memorySegment, offset, len); + long prevOffset = offset; + offset += len + 1; + + int value = 0; + if (offset < size - 8) { + long numberWord = memorySegment.get(JAVA_LONG_LT, offset); + int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); + value = convertIntoNumber(decimalSepPos, numberWord); + offset += (decimalSepPos >>> 3) + 3; + } + else { + long currentPosition = offset; + int sign = 1; + byte b = memorySegment.get(ValueLayout.JAVA_BYTE, currentPosition++); + if (b == '-') { + sign = -1; + } + else { + value = b - '0'; + } + while ((b = memorySegment.get(ValueLayout.JAVA_BYTE, currentPosition++)) != '.') { + value = value * 10 + (b - '0'); + } + b = memorySegment.get(ValueLayout.JAVA_BYTE, currentPosition); + value = value * 10 + (b - '0'); + if (sign == -1) { + value = -value; + } + offset = currentPosition + 2; + } + if (len < BYTE_SPECIES_LANES) { + list.addWithByteVector(chunk1, len, hash, value, memorySegment, prevOffset); + } + else { + list.add(len, hash, value, memorySegment, prevOffset); + } + } + return list; } catch (IOException e) { - System.out.println("Error"); - System.err.println(e); + throw new RuntimeException(e); } - return list; } - private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder(); + private static final int GOLDEN_RATIO = 0x9E3779B9; + private static final int HASH_LROTATE = 5; - private static long getLongLittleEndian(long value) { - value = Long.reverseBytes(value); - return value; + private static int hash(MemorySegment memorySegment, long start, int len) { + int x; + int y; + if (len >= Integer.BYTES) { + x = memorySegment.get(ValueLayout.JAVA_INT_UNALIGNED, start); + y = memorySegment.get(ValueLayout.JAVA_INT_UNALIGNED, start + len - Integer.BYTES); + } + else { + x = memorySegment.get(ValueLayout.JAVA_BYTE, start); + y = memorySegment.get(ValueLayout.JAVA_BYTE, start + len - Byte.BYTES); + } + return (Integer.rotateLeft(x * GOLDEN_RATIO, HASH_LROTATE) ^ y) * GOLDEN_RATIO; } private static int convertIntoNumber(int decimalSepPos, long numberWord) { @@ -405,16 +455,11 @@ public class CalculateAverage_giovannicuccu { return (int) value; } - private static long[] masks = new long[]{ 0x0000000000000000, 0xFF00000000000000L, 0xFFFF000000000000L, - 0xFFFFFF0000000000L, 0xFFFFFFFF00000000L, 0xFFFFFFFFFF000000L, 0xFFFFFFFFFF0000L, 0xFFFFFFFFFFFF00L }; - } public static void main(String[] args) throws IOException { - long start = System.currentTimeMillis(); - MMapReader reader = new MMapReader(Paths.get(FILE), new ProcessorPartitionCalculator(), false); - Map measurements = reader.elaborate(); - // System.out.println("ela=" + (System.currentTimeMillis() - start)); + MMapReaderMemorySegment reader = new MMapReaderMemorySegment(Paths.get(FILE), new ProcessorPartitionCalculator(), false); + Map measurements = reader.elaborate(); System.out.println(measurements); } diff --git a/src/main/java/dev/morling/onebrc/CreateMeasurements3.java b/src/main/java/dev/morling/onebrc/CreateMeasurements3.java index 804b83c..9bcc16d 100644 --- a/src/main/java/dev/morling/onebrc/CreateMeasurements3.java +++ b/src/main/java/dev/morling/onebrc/CreateMeasurements3.java @@ -55,7 +55,7 @@ public class CreateMeasurements3 { out.write(station.name); out.write(';'); out.write(Double.toString(Math.round(temp * 10.0) / 10.0)); - out.newLine(); + out.write('\n'); if (i % 50_000_000 == 0) { System.out.printf("Wrote %,d measurements in %,d ms%n", i, System.currentTimeMillis() - start); }