diff --git a/calculate_average_iziamos.sh b/calculate_average_iziamos.sh index 9b18b1d..7ce3ff1 100755 --- a/calculate_average_iziamos.sh +++ b/calculate_average_iziamos.sh @@ -15,6 +15,5 @@ # limitations under the License. # - -JAVA_OPTS="--enable-preview" +JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -Xms16m -Xmx16m -XX:-AlwaysPreTouch -XX:-TieredCompilation -XX:CICompilerCount=1" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_iziamos diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_iziamos.java b/src/main/java/dev/morling/onebrc/CalculateAverage_iziamos.java index 53dc3aa..c0358b9 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_iziamos.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_iziamos.java @@ -15,57 +15,76 @@ */ package dev.morling.onebrc; +import sun.misc.Unsafe; + import java.io.IOException; -import java.nio.ByteBuffer; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.reflect.Field; import java.nio.channels.FileChannel; import java.nio.file.Files; import java.nio.file.Path; -import java.nio.file.StandardOpenOption; -import java.util.Arrays; import java.util.Map; import java.util.TreeMap; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.locks.ReentrantLock; -import java.util.stream.LongStream; +import static dev.morling.onebrc.CalculateAverage_iziamos.ByteBackedResultSet.mask; import static java.nio.channels.FileChannel.MapMode.READ_ONLY; import static java.nio.charset.StandardCharsets.UTF_8; +import static java.nio.file.StandardOpenOption.READ; public class CalculateAverage_iziamos { + private static final Unsafe UNSAFE; + private static final String FILE = "./measurements.txt"; - private static final int CHUNK_SIZE = 8 * 1024 * 1024; - private static final int NAME_ARRAY_LENGTH = 103; - private static final int NAME_ARRAY_LENGTH_POSITION = NAME_ARRAY_LENGTH - 1; - private static final int NAME_ARRAY_HASHCODE_POSITION = NAME_ARRAY_LENGTH - 2; - private final static ReentrantLock mergeLock = new ReentrantLock(); + private static final Arena GLOBAL_ARENA = Arena.global(); + private final static MemorySegment WHOLE_FILE_SEGMENT; + private final static long FILE_SIZE; + private final static long BASE_POINTER; + private final static long END_POINTER; + + static { + try { + final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + UNSAFE = (Unsafe) theUnsafe.get(Unsafe.class); + + final var fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), READ); + WHOLE_FILE_SEGMENT = fileChannel.map(READ_ONLY, 0, fileChannel.size(), GLOBAL_ARENA); + + } + catch (final NoSuchFieldException | IllegalAccessException | IOException e) { + throw new RuntimeException(e); + } + + FILE_SIZE = WHOLE_FILE_SEGMENT.byteSize(); + BASE_POINTER = WHOLE_FILE_SEGMENT.address(); + END_POINTER = BASE_POINTER + FILE_SIZE; + } + private static final long CHUNK_SIZE = 64 * 1024 * 1024; + // private static final long CHUNK_SIZE = Long.MAX_VALUE; public static void main(String[] args) throws Exception { - // Thread.sleep(10000); - final var channel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ); + // Thread.sleep(10_000); - final var fileSize = channel.size(); - final long threadCount = 1 + fileSize / CHUNK_SIZE; - final ResultSet aggregate = new ResultSet(); + final long threadCount = 1 + FILE_SIZE / CHUNK_SIZE; - final CompletableFuture taskCompleteFutureThing = CompletableFuture.allOf(LongStream.range(0, threadCount) - .mapToObj(t -> processSegment(channel, t, CHUNK_SIZE, t == threadCount - 1) - .thenAccept(result -> mergeResults(aggregate, result))) - .toArray(CompletableFuture[]::new)); + final var processingFutures = new CompletableFuture[(int) threadCount]; + for (int i = 0; i < threadCount; ++i) { + processingFutures[i] = processSegment(i, CHUNK_SIZE); + } - taskCompleteFutureThing.join(); + final long aggregate = (long) processingFutures[0].get(); + for (int i = 1; i < processingFutures.length; i++) { + final long r = (long) processingFutures[i].get(); + ByteBackedResultSet.merge(aggregate, r); + } final Map output = new TreeMap<>(); - - aggregate.forEach((name, max, min, sum, count) -> output.put(nameToString(name), new ResultRow(min, (double) sum / count, max))); + ByteBackedResultSet.forEach(aggregate, + (name, min, max, sum, count) -> output.put(name, new ResultRow(min, (double) sum / count, max))); System.out.println(output); - // System.out.println(Arrays.stream(aggregate.counts).sum()); - } - - private static void mergeResults(final ResultSet aggregate, final ResultSet result) { - mergeLock.lock(); - aggregate.merge(result); - mergeLock.unlock(); } private record ResultRow(long min, double mean, long max) { @@ -74,118 +93,138 @@ public class CalculateAverage_iziamos { } private double formatLong(final long value) { - return value / 10.0; } + private double round(double value) { return Math.round(value) / 10.0; } } - private static CompletableFuture processSegment(final FileChannel channel, - final long chunkNumber, - final long size, - final boolean isLast) { - final var result = new CompletableFuture(); + private static CompletableFuture processSegment(final long chunkNumber, final long chunkSize) { + final var ret = new CompletableFuture(); Thread.ofVirtual().start(() -> { - try { - final long start = chunkNumber * size; - final long memoryMapSize = mapsize(channel.size(), start, size, isLast); - final ByteBuffer mmap = channel.map(READ_ONLY, start, memoryMapSize); - skipIncomplete(mmap, start); - result.complete(processEvents(mmap, isLast ? memoryMapSize : size)); - } - catch (IOException e) { - result.completeExceptionally(e); - } + final long relativeStart = chunkNumber * chunkSize; + final long absoluteStart = BASE_POINTER + relativeStart; + + final long absoluteEnd = computeAbsoluteEndWithSlack(absoluteStart + chunkSize); + final long startOffsetAfterSkipping = skipIncomplete(WHOLE_FILE_SEGMENT.address(), absoluteStart); + + final long result = processEvents(startOffsetAfterSkipping, absoluteEnd); + ret.complete(result); }); + return ret; + } + + private static long computeAbsoluteEndWithSlack(final long chunk) { + return Long.compareUnsigned(END_POINTER, chunk) > 0 ? chunk : END_POINTER; + } + + private static long skipIncomplete(final long basePointer, final long start) { + if (start == basePointer) { + return start; + } + for (long i = 0;; ++i) { + final byte b = UNSAFE.getByte(start + i); + if (b == '\n') { + return start + i + 1; + } + } + } + + private static long processEvents(final long start, final long limit) { + final long result = ByteBackedResultSet.createResultSet(); + scalarLoop(start, limit, result); return result; } - private static long mapsize(final long total, final long start, final long chunk, final boolean isLast) { - final long chunkWithSomeOverlap = chunk + 128; - if (isLast) { - return total - start; - } - - return chunkWithSomeOverlap; - } - - private static void skipIncomplete(final ByteBuffer buffer, final long start) { - if (start == 0) { - return; - } - for (byte b = buffer.get();; b = buffer.get()) { - if (b == '\n') - return; + private static void scalarLoop(final long start, final long limit, final long result) { + final var cursor = new ScalarLoopCursor(start, limit); + while (cursor.hasMore()) { + final long address = cursor.getCurrentAddress(); + final int length = cursor.getStringLength(); + final int hash = cursor.getHash(); + final int value = cursor.getCurrentValue(); + ByteBackedResultSet.put(result, address, length, hash, value); } } - private static ResultSet processEvents(final ByteBuffer buffer, final long limit) { - final var result = new ResultSet(); - int[] nameBuffer = new int[NAME_ARRAY_LENGTH]; - while (buffer.hasRemaining() && buffer.position() <= limit) { - nameBuffer = processEvent(buffer, nameBuffer, result); + public static class ScalarLoopCursor { + private long pointer; + private final long limit; + + private int hash = 0; + + public ScalarLoopCursor(final long pointer, final long limit) { + this.pointer = pointer; + this.limit = limit; } - return result; - } - private static int[] processEvent(final ByteBuffer buffer, final int[] nameBuffer, final ResultSet map) { - parseName(buffer, nameBuffer); - final int value = readValue(buffer); - - return map.put(nameBuffer, value) ? new int[NAME_ARRAY_LENGTH] : nameBuffer; - } - - private static void parseName(final ByteBuffer buffer, final int[] name) { - byte i = 0; - int hash = 0; - for (byte b = buffer.get(); b != ';'; b = buffer.get(), ++i) { - writeByte(name, i, b); - hash = 31 * hash + b; + public long getCurrentAddress() { + return pointer; } - setNameArrayLength(name, i); - setNameArrayHash(name, hash); - } - private static void writeByte(final int[] name, final int i, final byte b) { - name[i] = b; - } + public int getStringLength() { + int strLen = 0; + hash = 0; - private static int readValue(final ByteBuffer buffer) { - final byte first = buffer.get(); - final boolean isNegative = first == '-'; + byte b = UNSAFE.getByte(pointer); + for (; b != ';'; ++strLen, b = UNSAFE.getByte(pointer + strLen)) { + hash += b << strLen; + } + pointer += strLen + 1; - int value = digitCharToInt(isNegative ? buffer.get() : first); + return strLen; + } - final byte second = buffer.get(); - value = addSecondDigitIfPresent(buffer, second, value); - value = addDecimal(buffer, value); + public int getHash() { + return mask(hash); + } - consumeNewLine(buffer); - return isNegative ? -value : value; - } + public int getCurrentValue() { + final byte first = UNSAFE.getByte(pointer++); + final byte second = UNSAFE.getByte(pointer++); + final byte third = UNSAFE.getByte(pointer++); + final byte fourth = UNSAFE.getByte(pointer++); + final byte fifth = UNSAFE.getByte(pointer++); - private static void consumeNewLine(final ByteBuffer buffer) { - if (buffer.hasRemaining()) { - buffer.get(); + int value; + if (second == '.') { + // D.D\n + value = appendDigit(digitCharToInt(first), third); + pointer--; + return value; + } + else if (fourth == '.') { + // -DD.D\n + value = digitCharToInt(second); + value = appendDigit(value, third); + value = -appendDigit(value, fifth); + pointer++; + return value; + } + else if (first == '-') { + // -D.D\n + return -appendDigit(digitCharToInt(second), fourth); + } + else { + // DD.D\n + value = digitCharToInt(first); + value = appendDigit(value, second); + return appendDigit(value, fourth); + } + } + + public boolean hasMore() { + return pointer < limit; } } - private static int addDecimal(final ByteBuffer buffer, int value) { + private static int appendDigit(int value, final byte b) { value *= 10; - value += digitCharToInt(buffer.get()); - return value; - } - - private static int addSecondDigitIfPresent(final ByteBuffer buffer, final byte second, int value) { - if (second != '.') { - value *= 10; - value += digitCharToInt(second); - buffer.get(); - } + value += digitCharToInt(b); return value; } @@ -193,128 +232,150 @@ public class CalculateAverage_iziamos { return b - '0'; } - private interface ResultConsumer { - void consume(final int[] name, final long max, final long min, final long sum, final long count); + public interface ResultConsumer { + void consume(final String name, final int min, final int max, final long sum, final long count); } - private static class ResultSet { + static class ByteBackedResultSet { private static final int MAP_SIZE = 16384; private static final int MASK = MAP_SIZE - 1; + private static final long STRUCT_SIZE = 64; + private static final long BYTE_SIZE = MAP_SIZE * STRUCT_SIZE; + private static final long STRING_OFFSET = 0; + private static final long STRING_LEN_OFFSET = 8; + private static final long HASH_OFFSET = 12; + private static final long MIN_OFFSET = 16; + private static final long MAX_OFFSET = 20; + private static final long SUM_OFFSET = 24; + private static final long COUNT_OFFSET = 32; - private final int[][] names = new int[MAP_SIZE][]; - private final long[] maximums = new long[MAP_SIZE]; - private final long[] minimums = new long[MAP_SIZE]; - private final long[] sums = new long[MAP_SIZE]; - private final long[] counts = new long[MAP_SIZE]; - - ResultSet() { - if (Integer.bitCount(MAP_SIZE) != 1) { - throw new RuntimeException("blah"); - } - Arrays.fill(maximums, Long.MIN_VALUE); - Arrays.fill(minimums, Long.MAX_VALUE); + public static long createResultSet() { + final long baseAddress = UNSAFE.allocateMemory(BYTE_SIZE); + UNSAFE.setMemory(baseAddress, BYTE_SIZE, (byte) 0); + return baseAddress; } - /** - * @return true if the name is new - */ - public boolean put(final int[] name, long value) { - final int hash = name[NAME_ARRAY_HASHCODE_POSITION]; - final int slot = findSlot(hash, name); - return insert(slot, name, value); + public static void put(final long baseAddress, final long address, final int length, final int hash, final int value) { + final long slot = findSlot(baseAddress, hash, address, length); + final long structBase = baseAddress + (slot * STRUCT_SIZE); + + final int min = UNSAFE.getInt(structBase + MIN_OFFSET); + final int max = UNSAFE.getInt(structBase + MAX_OFFSET); + final long sum = UNSAFE.getLong(structBase + SUM_OFFSET); + final long count = UNSAFE.getLong(structBase + COUNT_OFFSET); + + UNSAFE.putLong(structBase, address); + UNSAFE.putInt(structBase + STRING_LEN_OFFSET, length); + UNSAFE.putInt(structBase + HASH_OFFSET, hash); + + UNSAFE.putInt(structBase + MIN_OFFSET, Math.min(value, min)); + UNSAFE.putInt(structBase + MAX_OFFSET, Math.max(value, max)); + UNSAFE.putLong(structBase + SUM_OFFSET, sum + value); + UNSAFE.putLong(structBase + COUNT_OFFSET, count + 1); } - public void forEach(final ResultConsumer consumer) { - for (int i = 0; i < ResultSet.MAP_SIZE; ++i) { - final int[] name = names[i]; - - if (name == null) { + public static void forEach(final long baseAddress, final ResultConsumer resultConsumer) { + for (long i = 0; i < BYTE_SIZE; i += STRUCT_SIZE) { + final long structBase = baseAddress + i; + final long stringBase = UNSAFE.getLong(structBase); + if (stringBase == 0) { continue; } - consumer.consume(name, maximums[i], minimums[i], sums[i], counts[i]); + final int min = UNSAFE.getInt(structBase + MIN_OFFSET); + final int max = UNSAFE.getInt(structBase + MAX_OFFSET); + final long sum = UNSAFE.getLong(structBase + SUM_OFFSET); + final long count = UNSAFE.getLong(structBase + COUNT_OFFSET); + + final int strLen = UNSAFE.getInt(structBase + STRING_LEN_OFFSET); + final byte[] bytes = new byte[strLen]; + for (int j = 0; j < strLen; ++j) { + bytes[j] = UNSAFE.getByte(stringBase + j); + } + + resultConsumer.consume(new String(bytes, UTF_8), min, max, sum, count); } } - public void merge(final ResultSet other) { - other.forEach((name, max, min, sum, count) -> { - final int hash = name[NAME_ARRAY_HASHCODE_POSITION]; - final int slot = findSlot(hash, name); - mergeValues(slot, name, min, max, sum, count); - }); + public static void merge(final long baseAddress, final long other) { + for (long i = 0; i < BYTE_SIZE; i += STRUCT_SIZE) { + final long otherStructBase = other + i; + if (UNSAFE.getLong(otherStructBase) == 0) { + continue; + } + final long otherStringStart = UNSAFE.getLong(otherStructBase); + final int otherStringLength = UNSAFE.getInt(otherStructBase + STRING_LEN_OFFSET); + final int otherStringHash = UNSAFE.getInt(otherStructBase + HASH_OFFSET); + + final long slot = findSlot(baseAddress, otherStringHash, otherStringStart, otherStringLength); + + final long thisStructBase = baseAddress + (slot * STRUCT_SIZE); + + final int min = UNSAFE.getInt(thisStructBase + MIN_OFFSET); + final int max = UNSAFE.getInt(thisStructBase + MAX_OFFSET); + final long sum = UNSAFE.getLong(thisStructBase + SUM_OFFSET); + final long count = UNSAFE.getLong(thisStructBase + COUNT_OFFSET); + + final int otherMin = UNSAFE.getInt(otherStructBase + MIN_OFFSET); + final int otherMax = UNSAFE.getInt(otherStructBase + MAX_OFFSET); + final long otherSum = UNSAFE.getLong(otherStructBase + SUM_OFFSET); + final long otherCount = UNSAFE.getLong(otherStructBase + COUNT_OFFSET); + + UNSAFE.putLong(thisStructBase, otherStringStart); + UNSAFE.putInt(thisStructBase + STRING_LEN_OFFSET, otherStringLength); + UNSAFE.putInt(thisStructBase + HASH_OFFSET, otherStringHash); + + UNSAFE.putInt(thisStructBase + MIN_OFFSET, Math.min(otherMin, min)); + UNSAFE.putInt(thisStructBase + MAX_OFFSET, Math.max(otherMax, max)); + UNSAFE.putLong(thisStructBase + SUM_OFFSET, sum + otherSum); + UNSAFE.putLong(thisStructBase + COUNT_OFFSET, count + otherCount); + } } - private int findSlot(final int hash, final int[] name) { - for (int slot = mask(hash);; slot = mask(++slot)) { - if (isCorrectSlot(name, slot)) { + private static int findSlot(final long baseAddress, + final int hash, + final long otherStringAddress, + final int otherStringLength) { + + for (int slot = hash;; slot = mask(++slot)) { + final long structBase = baseAddress + ((long) slot * STRUCT_SIZE); + final long nameStart = UNSAFE.getLong(structBase); + if (nameStart == 0) { + UNSAFE.putInt(structBase + MIN_OFFSET, Integer.MAX_VALUE); + UNSAFE.putInt(structBase + MAX_OFFSET, Integer.MIN_VALUE); + return slot; + } + + final int nameLength = UNSAFE.getInt(structBase + STRING_LEN_OFFSET); + if (stringEquals(nameStart, nameLength, otherStringAddress, otherStringLength)) { return slot; } } } - private boolean isCorrectSlot(final int[] name, final int slot) { - return names[slot] == null || nameArrayEquals(names[slot], name); - } - - private int mask(final long key) { - return (int) (key & MASK); - } - - private boolean insert(final int slot, final int[] name, final long value) { - final int[] currentValue = names[slot]; - updateValues(slot, value); - if (currentValue == null) { - names[slot] = name; - return true; + private static boolean stringEquals(final long thisNameAddress, final int thisStringLength, final long otherNameAddress, final long otherNameLength) { + if (thisStringLength != otherNameLength) { + return false; } - return false; + + int i = 0; + for (; i < thisStringLength - 3; i += 4) { + if (UNSAFE.getInt(thisNameAddress + i) != UNSAFE.getInt(otherNameAddress + i)) { + return false; + } + } + + final int remainingToCheck = thisStringLength - i; + final int finalBytesMask = ((1 << remainingToCheck * 8)) - 1; + final int thisLastWord = UNSAFE.getInt(thisNameAddress + i); + final int otherLastWord = UNSAFE.getInt(otherNameAddress + i); + + return 0 == ((thisLastWord ^ otherLastWord) & finalBytesMask); } - private void updateValues(final int slot, final long value) { - maximums[slot] = Math.max(maximums[slot], value); - minimums[slot] = Math.min(minimums[slot], value); - sums[slot] += value; - counts[slot]++; + public static int mask(final int value) { + return MASK & value; } - - private void mergeValues(final int slot, - final int[] name, - final long min, - final long max, - final long sum, - final long count) { - names[slot] = name; - maximums[slot] = Math.max(maximums[slot], max); - minimums[slot] = Math.min(minimums[slot], min); - sums[slot] += sum; - counts[slot] += count; - } - } - - private static boolean nameArrayEquals(final int[] a, final int[] b) { - return Arrays.equals(a, 0, getNameArrayLength(a), b, 0, getNameArrayLength(b)); - } - - private static int getNameArrayLength(final int[] name) { - return name[NAME_ARRAY_LENGTH_POSITION]; - } - - private static void setNameArrayLength(final int[] name, int length) { - name[NAME_ARRAY_LENGTH_POSITION] = length; - } - - private static void setNameArrayHash(final int[] name, int hash) { - name[NAME_ARRAY_HASHCODE_POSITION] = hash; - } - - private static String nameToString(final int[] name) { - final int nameArrayLength = getNameArrayLength(name); - final byte[] bytes = new byte[nameArrayLength]; - for (int i = 0; i < nameArrayLength; ++i) { - bytes[i] = (byte) name[i]; - } - - return new String(bytes, UTF_8); } }