From 06f9b748297d5929e6cc5f913d10a04cf0501381 Mon Sep 17 00:00:00 2001 From: zerninv Date: Fri, 12 Jan 2024 08:54:28 +0000 Subject: [PATCH] use unsafe (#343) --- calculate_average_zerninv.sh | 2 +- .../onebrc/CalculateAverage_zerninv.java | 181 +++++++++++------- 2 files changed, 110 insertions(+), 73 deletions(-) diff --git a/calculate_average_zerninv.sh b/calculate_average_zerninv.sh index 1cc4197..2b76c7d 100755 --- a/calculate_average_zerninv.sh +++ b/calculate_average_zerninv.sh @@ -15,5 +15,5 @@ # limitations under the License. # -JAVA_OPTS="" +JAVA_OPTS="--enable-preview" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_zerninv \ No newline at end of file diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java index 4e6b255..0ca1141 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java @@ -15,9 +15,14 @@ */ package dev.morling.onebrc; +import sun.misc.Unsafe; + +import java.io.BufferedOutputStream; import java.io.IOException; -import java.nio.MappedByteBuffer; +import java.lang.foreign.Arena; +import java.lang.reflect.Field; import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.*; @@ -27,28 +32,33 @@ import java.util.concurrent.Future; public class CalculateAverage_zerninv { private static final String FILE = "./measurements.txt"; - private static final int MIN_FILE_SIZE = 1024 * 1024; + private static final int MIN_CHUNK_SIZE = 1024 * 1024 * 16; private static final char DELIMITER = ';'; private static final char LINE_SEPARATOR = '\n'; private static final char ZERO = '0'; private static final char NINE = '9'; private static final char MINUS = '-'; + private static final Unsafe UNSAFE = initUnsafe(); + public static void main(String[] args) throws IOException { var results = new HashMap(); try (var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { var fileSize = channel.size(); + var memorySegment = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()); + long address = memorySegment.address(); var cores = Runtime.getRuntime().availableProcessors(); - var chunks = cores - 1; - var maxChunkSize = fileSize < MIN_FILE_SIZE ? fileSize : Math.min(fileSize / chunks, Integer.MAX_VALUE); - var chunkOffsets = splitByChunks(channel, maxChunkSize); + var chunkAmount = cores - 1; + // var maxChunkSize = Math.min(fileSize, MIN_CHUNK_SIZE); + var maxChunkSize = fileSize < MIN_CHUNK_SIZE ? fileSize : fileSize / chunkAmount; + var chunks = splitByChunks(address, address + fileSize, maxChunkSize); var executor = Executors.newFixedThreadPool(cores); List>> fResults = new ArrayList<>(); - for (int i = 1; i < chunkOffsets.size(); i++) { - final long prev = chunkOffsets.get(i - 1); - final long curr = chunkOffsets.get(i); - fResults.add(executor.submit(() -> calcForChunk(channel, prev, curr))); + for (int i = 1; i < chunks.size(); i++) { + final long prev = chunks.get(i - 1); + final long curr = chunks.get(i); + fResults.add(executor.submit(() -> calcForChunk(prev, curr))); } fResults.forEach(f -> { @@ -69,49 +79,62 @@ public class CalculateAverage_zerninv { }); executor.shutdown(); } - System.out.println(new TreeMap<>(results)); + + var bos = new BufferedOutputStream(System.out); + bos.write(new TreeMap<>(results).toString().getBytes(StandardCharsets.UTF_8)); + bos.write('\n'); + bos.flush(); } - private static List splitByChunks(FileChannel channel, long maxChunkSize) throws IOException { - long size = channel.size(); + private static Unsafe initUnsafe() { + try { + Field unsafe = Unsafe.class.getDeclaredField("theUnsafe"); + unsafe.setAccessible(true); + return (Unsafe) unsafe.get(Unsafe.class); + } + catch (IllegalAccessException | NoSuchFieldException e) { + throw new RuntimeException(e); + } + } + + private static List splitByChunks(long address, long end, long maxChunkSize) { List result = new ArrayList<>(); - long current = 0; - result.add(current); - while (current < size) { - var mbb = channel.map(FileChannel.MapMode.READ_ONLY, current, Math.min(size - current, maxChunkSize)); - int position = mbb.limit() - 1; - while (mbb.get(position) != LINE_SEPARATOR) { - position--; + result.add(address); + while (address < end) { + long ptr = address + Math.min(end - address, maxChunkSize) - 1; + while (UNSAFE.getByte(ptr) != LINE_SEPARATOR) { + ptr--; } - current += position + 1; - result.add(current); + address = ptr + 1; + result.add(address); } return result; } - private static Map calcForChunk(FileChannel channel, long begin, long end) throws IOException { - var mbb = channel.map(FileChannel.MapMode.READ_ONLY, begin, end - begin); - var results = new MeasurementContainer(mbb); - int cityOffset, cityNameSize, hashCode, temperatureOffset, temperature; - byte b; + private static Map calcForChunk(long offset, long end) { + var results = new MeasurementContainer(); - while (mbb.hasRemaining()) { - cityOffset = mbb.position(); + long cityOffset, temperatureOffset; + int hashCode, temperature; + byte cityNameSize, b; + + while (offset < end) { + cityOffset = offset; hashCode = 0; - while ((b = mbb.get()) != DELIMITER) { + while ((b = UNSAFE.getByte(offset++)) != DELIMITER) { hashCode = 31 * hashCode + b; } - temperatureOffset = mbb.position(); - cityNameSize = temperatureOffset - cityOffset - 1; + temperatureOffset = offset; + cityNameSize = (byte) (temperatureOffset - cityOffset - 1); temperature = 0; - while ((b = mbb.get()) != LINE_SEPARATOR) { + while ((b = UNSAFE.getByte(offset++)) != LINE_SEPARATOR) { if (b >= ZERO && b <= NINE) { temperature = temperature * 10 + (b - ZERO); } } - if (mbb.get(temperatureOffset) == MINUS) { + if (UNSAFE.getByte(temperatureOffset) == MINUS) { temperature *= -1; } results.put(cityOffset, cityNameSize, hashCode, (short) temperature); @@ -121,11 +144,11 @@ public class CalculateAverage_zerninv { private static final class MeasurementAggregation { private long sum; - private int count; + private long count; private short min; private short max; - public MeasurementAggregation(long sum, int count, short min, short max) { + public MeasurementAggregation(long sum, long count, short min, short max) { this.sum = sum; this.count = count; this.min = min; @@ -151,74 +174,88 @@ public class CalculateAverage_zerninv { private static final class MeasurementContainer { private static final int SIZE = 1024 * 16; - private final MappedByteBuffer mbb; - private final int[] offsets = new int[SIZE]; - private final int[] sizes = new int[SIZE]; - private final int[] hashes = new int[SIZE]; + private static final int ENTRY_SIZE = 8 + 1 + 4 + 8 + 8 + 2 + 2; + private static final int COUNT_OFFSET = 0; + private static final int SIZE_OFFSET = 8; + private static final int HASH_OFFSET = 9; + private static final int ADDRESS_OFFSET = 13; + private static final int SUM_OFFSET = 21; + private static final int MIN_OFFSET = 29; + private static final int MAX_OFFSET = 31; - private final long[] sums = new long[SIZE]; - private final int[] counts = new int[SIZE]; - private final short[] mins = new short[SIZE]; - private final short[] maxs = new short[SIZE]; + private final long address; - private MeasurementContainer(MappedByteBuffer mbb) { - this.mbb = mbb; - Arrays.fill(mins, Short.MAX_VALUE); - Arrays.fill(maxs, Short.MIN_VALUE); + private MeasurementContainer() { + address = UNSAFE.allocateMemory(ENTRY_SIZE * SIZE); + UNSAFE.setMemory(address, ENTRY_SIZE * SIZE, (byte) 0); + for (long ptr = address; ptr < address + SIZE * ENTRY_SIZE; ptr += ENTRY_SIZE) { + UNSAFE.putShort(ptr + MIN_OFFSET, Short.MAX_VALUE); + UNSAFE.putShort(ptr + MAX_OFFSET, Short.MIN_VALUE); + } } - public void put(int offset, int size, int hash, short value) { - int i = findIdx(offset, size, hash); - offsets[i] = offset; - sizes[i] = size; - hashes[i] = hash; + public void put(long address, byte size, int hash, short value) { + long ptr = findAddress(address, size, hash); - sums[i] += value; - counts[i]++; + UNSAFE.putLong(ptr + COUNT_OFFSET, UNSAFE.getLong(ptr + COUNT_OFFSET) + 1); + UNSAFE.putByte(ptr + SIZE_OFFSET, size); + UNSAFE.putInt(ptr + HASH_OFFSET, hash); + UNSAFE.putLong(ptr + ADDRESS_OFFSET, address); - if (value < mins[i]) { - mins[i] = value; + UNSAFE.putLong(ptr + SUM_OFFSET, UNSAFE.getLong(ptr + SUM_OFFSET) + value); + if (value < UNSAFE.getShort(ptr + MIN_OFFSET)) { + UNSAFE.putShort(ptr + MIN_OFFSET, value); } - if (value > maxs[i]) { - maxs[i] = value; + if (value > UNSAFE.getShort(ptr + MAX_OFFSET)) { + UNSAFE.putShort(ptr + MAX_OFFSET, value); } } public Map toStringMap() { var result = new HashMap(); for (int i = 0; i < SIZE; i++) { - if (counts[i] != 0) { - var key = createString(offsets[i], sizes[i]); - result.put(key, new MeasurementAggregation(sums[i], counts[i], mins[i], maxs[i])); + long ptr = this.address + i * ENTRY_SIZE; + if (UNSAFE.getLong(ptr + COUNT_OFFSET) != 0) { + var measurements = new MeasurementAggregation( + UNSAFE.getLong(ptr + SUM_OFFSET), + UNSAFE.getLong(ptr + COUNT_OFFSET), + UNSAFE.getShort(ptr + MIN_OFFSET), + UNSAFE.getShort(ptr + MAX_OFFSET)); + var key = createString(UNSAFE.getLong(ptr + ADDRESS_OFFSET), UNSAFE.getByte(ptr + SIZE_OFFSET)); + result.put(key, measurements); } } return result; } - private int findIdx(int offset, int size, int hash) { - int i = Math.abs(hash % SIZE); - while (counts[i] != 0) { - if (hashes[i] == hash && sizes[i] == size && isEqual(i, offset)) { + private long findAddress(long address, byte size, int hash) { + int idx = Math.abs(hash % SIZE); + long ptr = this.address + idx * ENTRY_SIZE; + while (UNSAFE.getLong(ptr + COUNT_OFFSET) != 0) { + if (UNSAFE.getByte(ptr + SIZE_OFFSET) == size + && UNSAFE.getInt(ptr + HASH_OFFSET) == hash + && isEqual(UNSAFE.getLong(ptr + ADDRESS_OFFSET), address, size)) { break; } - i = (i + 1) % SIZE; + idx = (idx + 1) % SIZE; + ptr = this.address + idx * ENTRY_SIZE; } - return i; + return ptr; } - private boolean isEqual(int index, int offset) { - for (int i = 0; i < sizes[index]; i++) { - if (mbb.get(offsets[index] + i) != mbb.get(offset + i)) { + private boolean isEqual(long address, long address2, byte size) { + for (int i = 0; i < size; i++) { + if (UNSAFE.getByte(address + i) != UNSAFE.getByte(address2 + i)) { return false; } } return true; } - private String createString(int offset, int size) { + private String createString(long address, byte size) { byte[] arr = new byte[size]; for (int i = 0; i < size; i++) { - arr[i] = mbb.get(offset + i); + arr[i] = UNSAFE.getByte(address + i); } return new String(arr); }