From f37b304fc3006933dd1245218c41404c336df280 Mon Sep 17 00:00:00 2001 From: Parth Mudgal Date: Fri, 12 Jan 2024 14:08:09 +0530 Subject: [PATCH] inline hash calculation and number parsing (#200) no number parsing with precalculated map verify tests better loop with direct hash to measurement mapping accept formatting changes Use unsafe --- calculate_average_artpar.sh | 2 +- .../onebrc/CalculateAverage_artpar.java | 416 +++++++++--------- 2 files changed, 212 insertions(+), 206 deletions(-) diff --git a/calculate_average_artpar.sh b/calculate_average_artpar.sh index 56c9b25..7dfda89 100755 --- a/calculate_average_artpar.sh +++ b/calculate_average_artpar.sh @@ -16,5 +16,5 @@ # -JAVA_OPTS="--add-modules=jdk.incubator.vector" +JAVA_OPTS="--enable-preview" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_artpar diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artpar.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artpar.java index 835e65e..4faf322 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_artpar.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artpar.java @@ -15,11 +15,16 @@ */ package dev.morling.onebrc; +import sun.misc.Unsafe; + import java.io.BufferedOutputStream; import java.io.IOException; import java.io.PrintStream; import java.io.RandomAccessFile; -import java.nio.MappedByteBuffer; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.reflect.Field; +import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Files; @@ -37,17 +42,22 @@ import java.util.stream.Collectors; public class CalculateAverage_artpar { public static final int N_THREADS = 8; private static final String FILE = "./measurements.txt"; + private static final int INT_MAP_SIZE = 8192; // from calculateIntegerByteMapTest() + final static int[] byteHashMapToInt = calculateIntegerByteMap(); + private static final Unsafe UNSAFE = initUnsafe(); // private static final VectorSpecies SPECIES = IntVector.SPECIES_PREFERRED; // final int VECTOR_SIZE = 512; // final int VECTOR_SIZE_1 = VECTOR_SIZE - 1; - final int SIZE = 1024 * 1024; + final int AVERAGE_CHUNK_SIZE = 1024 * 64; + final int AVERAGE_CHUNK_SIZE_1 = AVERAGE_CHUNK_SIZE - 1; public CalculateAverage_artpar() throws IOException { long start = Instant.now().toEpochMilli(); Path measurementFile = Paths.get(FILE); long fileSize = Files.size(measurementFile); - long expectedChunkSize = Math.max(fileSize / 8, 1024); + // System.out.println("File size - " + fileSize); + int expectedChunkSize = Math.toIntExact(Math.min(fileSize / N_THREADS, Integer.MAX_VALUE / 2)); ExecutorService threadPool = Executors.newFixedThreadPool(N_THREADS); @@ -56,52 +66,50 @@ public class CalculateAverage_artpar { List>> futures = new ArrayList<>(); long bytesReadCurrent = 0; - try (FileChannel fileChannel = FileChannel.open(measurementFile, StandardOpenOption.READ)) { - for (int i = 0; i < 8; i++) { + FileChannel fileChannel = FileChannel.open(measurementFile, StandardOpenOption.READ); + for (int i = 0; chunkStartPosition < fileSize; i++) { - long chunkSize = expectedChunkSize; - chunkSize = fis.skipBytes(Math.toIntExact(chunkSize)); + int chunkSize = expectedChunkSize; + chunkSize = fis.skipBytes(chunkSize); - bytesReadCurrent += chunkSize; - while (((char) fis.read()) != '\n' && bytesReadCurrent < fileSize) { - chunkSize++; - bytesReadCurrent++; - } - - // System.out.println("[" + chunkStartPosition + "] - [" + (chunkStartPosition + chunkSize) + " bytes"); - if (chunkStartPosition + chunkSize >= fileSize) { - chunkSize = fileSize - chunkStartPosition; - } - if (chunkSize < 1) { - break; - } - if (chunkSize > Integer.MAX_VALUE) { - chunkSize = Integer.MAX_VALUE; - } - - MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, chunkStartPosition, - chunkSize); - - ReaderRunnable readerRunnable = new ReaderRunnable(mappedByteBuffer); - Future> future = threadPool.submit(readerRunnable::run); - // System.out.println("Added future [" + chunkStartPosition + "][" + chunkSize + "]"); - futures.add(future); - chunkStartPosition = chunkStartPosition + chunkSize + 1; + bytesReadCurrent += chunkSize; + while (((char) fis.read()) != '\n' && bytesReadCurrent < fileSize) { + chunkSize++; + bytesReadCurrent++; } + + // System.out.println("[" + chunkStartPosition + "] - [" + (chunkStartPosition + chunkSize) + " bytes"); + if (chunkStartPosition + chunkSize >= fileSize) { + chunkSize = (int) Math.min(fileSize - chunkStartPosition, Integer.MAX_VALUE); + } + if (chunkSize < 1) { + break; + } + if (chunkSize >= Integer.MAX_VALUE) { + throw new RuntimeException(); + } + + // MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, chunkStartPosition, + // chunkSize); + + ReaderRunnable readerRunnable = new ReaderRunnable(chunkStartPosition, chunkSize, fileChannel); + Future> future = threadPool.submit(readerRunnable::run); + // System.out.println("Added future [" + chunkStartPosition + "][" + chunkSize + "]"); + futures.add(future); + chunkStartPosition = chunkStartPosition + chunkSize + 1; } + fis.close(); - Map globalMap = futures.parallelStream() - .flatMap(future -> { - try { - return future.get().entrySet().stream(); - } - catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - }).parallel().collect(Collectors.toMap( - Map.Entry::getKey, Map.Entry::getValue, - MeasurementAggregator::combine)); + Map globalMap = futures.parallelStream().flatMap(future -> { + try { + return future.get().entrySet().stream(); + } + catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }).parallel().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, MeasurementAggregator::combine)); + fileChannel.close(); Map results = globalMap.entrySet().stream().parallel() .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().finish())); @@ -136,33 +144,99 @@ public class CalculateAverage_artpar { } + public static int[] calculateIntegerByteMapTest() { + int[] intToIntMap = null; + for (int j = 0; j < 10000; j++) { + int length = 2000 + j; + intToIntMap = new int[length]; + boolean hasHashClash = false; + Map byteHashToInt = new HashMap<>(); + for (int i = -999; i < 1000; i++) { + int hashCode = hashInteger(i); + + // String s = new String(value); + int position = hashCode & (length - 1); + // System.out.printf("%.1f => %s length [%d] hash [%d] => %d\n", number, s, s.length(), hashCode, position); + if (byteHashToInt.containsKey(hashCode) || intToIntMap[position] != 0) { + // System.err.println("HashClash [" + hashCode + "] -> " + + // byteHashToInt.get( + // hashCode) + " vs " + number + " == [" + position + "] =>" + intToIntMap[position]); + hasHashClash = true; + break; + } + else { + byteHashToInt.put(hashCode, i); + intToIntMap[position] = i; + } + } + if (!hasHashClash) { + // 8192 + System.out.println("NoHash clash at [" + length + "]"); + // throw new RuntimeException("clash"); + return intToIntMap; + } + + } + System.out.println("Fail"); + return null; + } + + private static int hashInteger(int i) { + float number = i / 10f; + String numberString = String.format("%.1f", number); + byte[] value = numberString.getBytes(); + + int hashCode = 1; + for (int k = 0; k < value.length; k++) { + hashCode = hashCode * 31 + value[k]; + } + return hashCode; + } + + public static int[] calculateIntegerByteMap() { + long start = System.currentTimeMillis(); + int[] intToIntMap = new int[INT_MAP_SIZE]; + for (int i = -999; i < 1000; i++) { + float number = i / 10f; + byte[] value = String.format("%.1f", number).getBytes(); + + int hashCode = 1; + for (byte b : value) { + hashCode = hashCode * 31 + b; + } + int position = hashCode & (INT_MAP_SIZE - 1); + intToIntMap[position] = i; + } + long end = System.currentTimeMillis(); + // System.out.println("calculateIntegerByteMap " + (end - start) + " ms"); + return intToIntMap; + } + public static void main(String[] args) throws IOException { new CalculateAverage_artpar(); } - public static int hashCode(byte[] array, int length) { - - int h = 1; - int i = 0; - for (; i + 7 < length; i += 8) { - h = 31 * 31 * 31 * 31 * 31 * 31 * 31 * 31 * h + 31 * 31 * 31 * 31 - * 31 * 31 * 31 * array[i] + 31 * 31 * 31 * 31 * 31 * 31 - * array[i + 1] - + 31 * 31 * 31 * 31 * 31 * array[i + 2] + 31 - * 31 * 31 * 31 * array[i + 3] - + 31 * 31 * 31 * array[i + 4] - + 31 * 31 * array[i + 5] + 31 * array[i + 6] + array[i + 7]; + private static Unsafe initUnsafe() { + try { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + return (Unsafe) theUnsafe.get(Unsafe.class); } - - for (; i + 3 < length; i += 4) { - h = 31 * 31 * 31 * 31 * h + 31 * 31 * 31 * array[i] + 31 * 31 - * array[i + 1] + 31 * array[i + 2] + array[i + 3]; - } - for (; i < length; i++) { - h = 31 * h + array[i]; + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); } + } - return h; + static boolean unsafeEquals(long aStart, long aLength, long bStart, long bLength) { + if (aLength != bLength) { + return false; + } + for (int i = 0; i < aLength; ++i) { + if (UNSAFE.getByte(aStart + i) != UNSAFE.getByte(bStart + i)) { + return false; + } + } + return true; } private record ResultRow(double min, double mean, double max) { @@ -180,43 +254,25 @@ public class CalculateAverage_artpar { } private static class MeasurementAggregator { - private double min = Double.POSITIVE_INFINITY; - private double max = Double.NEGATIVE_INFINITY; + private int min = 999; + private int max = -999; private double sum; private long count; - public MeasurementAggregator() { - } - - // public MeasurementAggregator(double min, double max, double sum, long count) { - // this.min = min; - // this.max = max; - // this.sum = sum; - // this.count = count; - // } - MeasurementAggregator combine(MeasurementAggregator other) { - min = Math.min(min, other.min); - max = Math.max(max, other.max); + min = other.min + ((min - other.min) & ((min - other.min) >> (32 * 8 - 1))); + max = max - ((max - other.max) & ((max - other.max) >> (32 * 8 - 1))); sum += other.sum; count += other.count; return this; } - // MeasurementAggregator combine(double otherMin, double otherMax, double otherSum, long otherCount) { - // min = Math.min(min, otherMin); - // max = Math.max(max, otherMax); - // sum += otherSum; - // count += otherCount; - // return this; - // } - - MeasurementAggregator combine(double value) { - min = Math.min(min, value); - max = Math.max(max, value); + void combine(int value) { sum += value; - count += 1; - return this; + count++; + + min = value + ((min - value) & ((min - value) >> (32 * 8 - 1))); // min(x, y) + max = max - ((max - value) & ((max - value) >> (32 * 8 - 1))); // max(x, y) } ResultRow finish() { @@ -227,150 +283,100 @@ public class CalculateAverage_artpar { static class StationName { public final int hash; - private final String name; - // private final int index; + private final ByteBuffer nameBytes; + private final MeasurementAggregator measurementAggregator = new MeasurementAggregator(); public int count = 0; - // public int[] values = new int[VECTOR_SIZE]; - public MeasurementAggregator measurementAggregator = new MeasurementAggregator(); - public StationName(String name, int hash) { - this.name = name; - // this.index = index; + public StationName(ByteBuffer nameBytes, int hash) { + this.nameBytes = nameBytes; this.hash = hash; } } private class ReaderRunnable { - private final MappedByteBuffer mappedByteBuffer; + private final long startPosition; + private final FileChannel fileChannel; + private final int chunkSize; StationNameMap stationNameMap = new StationNameMap(); - // double[][] stationValueMap = new double[SIZE][]; - private ReaderRunnable(MappedByteBuffer mappedByteBuffer) { - this.mappedByteBuffer = mappedByteBuffer; + private ReaderRunnable(long startPosition, int chunkSize, FileChannel fileChannel) throws IOException { + this.chunkSize = chunkSize; + this.startPosition = startPosition; + this.fileChannel = fileChannel; + // mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, startPosition, chunkSize); } - public Map run() { - // System.out.println("Started future - " + mappedByteBuffer.position()); + public Map run() throws IOException { + MemorySegment mappedSegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, + startPosition, chunkSize, Arena.global()); - int doubleValue; - long start = Date.from(Instant.now()).getTime(); - // int totalBytesRead = 0; - - // ByteBuffer nameBuffer = ByteBuffer.allocate(128); - int MAPPED_BYTE_BUFFER_SIZE = 8192; - byte[] rawBuffer = new byte[32]; + long rawBufferAddress = UNSAFE.allocateMemory(100); int rawBufferReadIndex = 0; - StationName matchedStation = null; - boolean expectedName = true; + long position = mappedSegment.address(); + long endPosition = position + chunkSize; + byte b; + int hash; + int nameHash; - byte[] mappedBytes = new byte[MAPPED_BYTE_BUFFER_SIZE]; - int mappedBytesReadIndex; - boolean negative = false; - int start1 = 0; - int result = 0; + hash = 1; - while (mappedByteBuffer.hasRemaining()) { - int remaining = mappedByteBuffer.remaining(); - int bytesToRead = Math.min(remaining, MAPPED_BYTE_BUFFER_SIZE); - mappedByteBuffer.get(mappedBytes, 0, bytesToRead); - remaining = mappedByteBuffer.remaining(); - mappedBytesReadIndex = 0; - - while (mappedBytesReadIndex < bytesToRead) { - byte b = mappedBytes[mappedBytesReadIndex]; - mappedBytesReadIndex++; - - if (expectedName) { - if (b != ';') { - rawBuffer[rawBufferReadIndex] = b; - rawBufferReadIndex++; - continue; - } - else { - expectedName = false; - matchedStation = stationNameMap.getOrCreate(rawBuffer, rawBufferReadIndex); - rawBufferReadIndex = 0; - negative = false; - start1 = 0; - result = 0; - continue; - } - } - - while (b != '\n') { - rawBuffer[rawBufferReadIndex] = b; - rawBufferReadIndex++; - - if (mappedBytesReadIndex < bytesToRead) { - b = mappedBytes[mappedBytesReadIndex]; - mappedBytesReadIndex++; - } - else { - break; - } - } - - if (b != '\n') { - if (mappedBytesReadIndex == bytesToRead && remaining > 0) { - continue; - } - } - - // Check for negative numbers - if (rawBuffer[0] == '-') { - negative = true; - start1++; - } - - for (int i = start1; i < rawBufferReadIndex; i++) { - byte c = rawBuffer[i]; - if (c != '.') { - result = result * 10 + (c - '0'); - } - } - - doubleValue = negative ? -result : result; - rawBufferReadIndex = 0; - matchedStation.measurementAggregator.combine(doubleValue); - matchedStation.count++; - expectedName = true; + while (position < endPosition) { + while ((position < endPosition) && + (b = UNSAFE.getByte(position++)) != ';') { + UNSAFE.putByte(rawBufferAddress + rawBufferReadIndex++, b); + hash = hash * 31 + b; } + nameHash = hash; + hash = 1; + + while ((position < endPosition) && + (b = UNSAFE.getByte(position++)) != '\n') { + hash = hash * 31 + b; + } + stationNameMap.getOrCreate(rawBufferAddress, rawBufferReadIndex, + byteHashMapToInt[hash & (INT_MAP_SIZE - 1)], nameHash); + rawBufferReadIndex = 0; + hash = 1; + } - - long end = Date.from(Instant.now()).getTime(); - // System.out.println("Took [" + ((end - start) / 1000) + "s for " + totalBytesRead / 1024 + " kb"); - - return Arrays.stream(stationNameMap.names).parallel().filter(Objects::nonNull) - .collect(Collectors.toMap(e -> e.name, e -> e.measurementAggregator)); - // return groupedMeasurements; + return Arrays.stream(stationNameMap.names).parallel().filter(Objects::nonNull).collect( + Collectors.toMap(e -> StandardCharsets.UTF_8.decode(e.nameBytes).toString(), + e -> e.measurementAggregator, MeasurementAggregator::combine)); } } class StationNameMap { - int[] indexes = new int[SIZE]; - StationName[] names = new StationName[SIZE]; + int[] indexes = new int[AVERAGE_CHUNK_SIZE]; + StationName[] names = new StationName[AVERAGE_CHUNK_SIZE]; int currentIndex = 0; + ByteBuffer bytesForName = ByteBuffer.allocateDirect(1000 * 100); + int nameBufferIndex = 0; - public StationName getOrCreate(byte[] stationNameBytes, int length) { - - int hash = CalculateAverage_artpar.hashCode(stationNameBytes, length); - - int position = Math.abs(hash) % SIZE; - while (indexes[position] != 0 && names[indexes[position]].hash != hash) { - position = ++position % SIZE; + public void getOrCreate(long stationNameBytesAddress, int length, int doubleValue, int hash) { + int position = hash & AVERAGE_CHUNK_SIZE_1; + while (indexes[position] != 0 && (names[indexes[position]].hash != hash)) { + position = ++position & AVERAGE_CHUNK_SIZE_1; } if (indexes[position] != 0) { - return names[indexes[position]]; + StationName stationName = names[indexes[position]]; + stationName.measurementAggregator.combine(doubleValue); + } + else { + ByteBuffer nameSlice = bytesForName.slice(nameBufferIndex, length); + nameBufferIndex += length; + for (int i = 0; i < length; i++) { + nameSlice.put(UNSAFE.getByte(stationNameBytesAddress + i)); + } + nameSlice.flip(); + StationName stationName = new StationName(nameSlice, hash); + indexes[position] = ++currentIndex; + names[indexes[position]] = stationName; + stationName.measurementAggregator.combine(doubleValue); } - StationName stationName = new StationName( - new String(stationNameBytes, 0, length, StandardCharsets.UTF_8), hash); - indexes[position] = ++currentIndex; - names[indexes[position]] = stationName; - return stationName; } } -} +} \ No newline at end of file