diff --git a/calculate_average_yavuztas.sh b/calculate_average_yavuztas.sh index 773e1de..bfa7b10 100755 --- a/calculate_average_yavuztas.sh +++ b/calculate_average_yavuztas.sh @@ -15,5 +15,5 @@ # limitations under the License. # -JAVA_OPTS="-Xms1g -Xmx1g" +JAVA_OPTS="-Xms128m -Xmx128m -XX:MaxGCPauseMillis=1 -XX:-AlwaysPreTouch -XX:+UseSerialGC --enable-preview" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yavuztas diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java b/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java index eb3d191..e33fe7e 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java @@ -15,77 +15,80 @@ */ package dev.morling.onebrc; -import java.io.Closeable; +import sun.misc.Unsafe; + import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.reflect.Field; import java.nio.ByteBuffer; -import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.HashMap; -import java.util.Map; import java.util.TreeMap; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.function.BiConsumer; +import java.util.function.Consumer; public class CalculateAverage_yavuztas { private static final Path FILE = Path.of("./measurements.txt"); - static class Measurement { + private static final Unsafe UNSAFE = unsafe(); - // Only accessed by a single thread, so it is safe to share - private static final StringBuilder STRING_BUILDER = new StringBuilder(14); - - private int min; // calculations over int is faster than double, we convert to double in the end only once - private int max; - private long sum; - private long count = 1; - - public Measurement(int initial) { - this.min = initial; - this.max = initial; - this.sum = initial; + // Tried all there: MappedByteBuffer, MemorySegment and Unsafe + // Accessing the memory using Unsafe is still the fastest in my experience + private static Unsafe unsafe() { + try { + final Field f = Unsafe.class.getDeclaredField("theUnsafe"); + f.setAccessible(true); + return (Unsafe) f.get(null); } - - public String toString() { - STRING_BUILDER.setLength(0); // clear the builder to reuse - STRING_BUILDER.append(this.min / 10.0); // convert to double while generating the string output - STRING_BUILDER.append("/"); - STRING_BUILDER.append(round((this.sum / 10.0) / this.count)); - STRING_BUILDER.append("/"); - STRING_BUILDER.append(this.max / 10.0); - return STRING_BUILDER.toString(); - } - - private double round(double value) { - return Math.round(value * 10.0) / 10.0; + catch (Exception e) { + throw new RuntimeException(e); } } - static class KeyBuffer { + // Only one object, both for measurements and keys, less object creation in hotpots is always faster + static class Record { - ByteBuffer buffer; + // keep memory starting address for each segment + // since we use Unsafe, this is enough to align and fetch the data + long segment; + int start; int length; int hash; - public KeyBuffer(ByteBuffer buffer, int length, int hash) { - this.buffer = buffer; + private int min = 1000; // calculations over int is faster than double, we convert to double in the end only once + private int max = -1000; + private long sum; + private long count; + + public Record(long segment, int start, int length, int hash) { + this.segment = segment; + this.start = start; this.length = length; this.hash = hash; } @Override public boolean equals(Object o) { - final KeyBuffer keyBuffer = (KeyBuffer) o; - if (this.length != keyBuffer.length || this.hash != keyBuffer.hash) + final Record record = (Record) o; + return equals(record.segment, record.start, record.length, record.hash); + } + + /** + * Stateless equals, no Record object needed + */ + public boolean equals(long segment, int start, int length, int hash) { + if (this.length != length || this.hash != hash) return false; - return this.buffer.equals(keyBuffer.buffer); + int i = 0; // bytes mismatch check + while (i < this.length + && UNSAFE.getByte(this.segment + this.start + i) == UNSAFE.getByte(segment + start + i)) { + i++; + } + return i == this.length; } @Override @@ -96,219 +99,284 @@ public class CalculateAverage_yavuztas { @Override public String toString() { final byte[] bytes = new byte[this.length]; - this.buffer.get(bytes); - return new String(bytes, 0, this.length, StandardCharsets.UTF_8); + int i = 0; + while (i < this.length) { + bytes[i] = UNSAFE.getByte(this.segment + this.start + i++); + } + + return new String(bytes, StandardCharsets.UTF_8); + } + + public Record collect(int temp) { + this.min = Math.min(this.min, temp); + this.max = Math.max(this.max, temp); + this.sum += temp; + this.count++; + return this; + } + + public void merge(Record other) { + this.min = Math.min(this.min, other.min); + this.max = Math.max(this.max, other.max); + this.sum += other.sum; + this.count += other.count; + } + + public String measurements() { + // here is only executed once for each unique key, so StringBuilder creation doesn't harm + final StringBuilder sb = new StringBuilder(14); + sb.append(this.min / 10.0); + sb.append("/"); + sb.append(round((this.sum / 10.0) / this.count)); + sb.append("/"); + sb.append(this.max / 10.0); + return sb.toString(); } } - static class FixedRegionDataAccessor { + // Inspired by @spullara - customized hashmap on purpose + // The main difference is we hold only one array instead of two + static class RecordMap { - long startPos; - long size; - ByteBuffer buffer; - int position; // relative + static final int SIZE = 1 << 15; // 32k - bigger bucket size less collisions + static final int BITMASK = SIZE - 1; + Record[] keys = new Record[SIZE]; - public FixedRegionDataAccessor(long startPos, long size, ByteBuffer buffer) { + static int hashBucket(int hash) { + hash = hash ^ (hash >>> 16); // naive bit spreading but surprisingly decreases collision :) + return hash & BITMASK; // fast modulo, to find bucket + } + + void putAndCollect(long segment, int start, int length, int hash, int temp) { + int bucket = hashBucket(hash); + Record existing = this.keys[bucket]; + if (existing == null) { + this.keys[bucket] = new Record(segment, start, length, hash) + .collect(temp); + return; + } + + if (!existing.equals(segment, start, length, hash)) { + // collision, linear probing to find a slot + while ((existing = this.keys[++bucket & BITMASK]) != null && !existing.equals(segment, start, length, hash)) { + // can be stuck here if all the buckets are full :( + // However, since the data set is max 10K (unique) this shouldn't happen + // So, I'm happily leave here branchless :) + } + if (existing == null) { + this.keys[bucket & BITMASK] = new Record(segment, start, length, hash) + .collect(temp); + return; + } + existing.collect(temp); + } + else { + existing.collect(temp); + } + } + + void putOrMerge(Record key) { + int bucket = hashBucket(key.hash); + Record existing = this.keys[bucket]; + if (existing == null) { + this.keys[bucket] = key; + return; + } + + if (!existing.equals(key)) { + // collision, linear probing to find a slot + while ((existing = this.keys[++bucket & BITMASK]) != null && !existing.equals(key)) { + // can be stuck here if all the buckets are full :( + // However, since the data set is max 10K (unique keys) this shouldn't happen + // So, I'm happily leave here branchless :) + } + if (existing == null) { + this.keys[bucket & BITMASK] = key; + return; + } + existing.merge(key); + } + else { + existing.merge(key); + } + } + + void forEach(Consumer consumer) { + int pos = 0; + Record key; + while (pos < this.keys.length) { + if ((key = this.keys[pos++]) == null) { + continue; + } + consumer.accept(key); + } + } + + void merge(RecordMap other) { + other.forEach(this::putOrMerge); + } + + } + + // One actor for one thread, no synchronization + static class RegionActor { + + final FileChannel channel; + final long startPos; + final int size; + final RecordMap map = new RecordMap(); + long segmentAddress; + int position; + Thread runner; // each actor has its own thread + + public RegionActor(FileChannel channel, long startPos, int size) { + this.channel = channel; this.startPos = startPos; this.size = size; - this.buffer = buffer; } - void traverse(BiConsumer consumer) { - int keyHash; - int length; - while (this.buffer.hasRemaining()) { - - this.position = this.buffer.position(); // save line start pos - - byte b; - keyHash = 0; - length = 0; - while ((b = this.buffer.get()) != ';') { // read until semicolon - keyHash = 31 * keyHash + b; // calculate key hash ahead, eleminates one more loop later - length++; + void accumulate() { + this.runner = new Thread(() -> { + try { + // get the segment memory address, this is the only thing we need for Unsafe + this.segmentAddress = this.channel.map(FileChannel.MapMode.READ_ONLY, this.startPos, this.size, Arena.global()).address(); + } + catch (IOException e) { + // no-op - skip intentionally, no handling for the purpose of this challenge } - final ByteBuffer station = this.buffer.slice(this.position, length); - final KeyBuffer key = new KeyBuffer(station, length, keyHash); - - this.buffer.mark(); // semicolon pos - skip(3); // skip more since minimum temperature length is 3 - length = 4; // +1 for semicolon - - while (this.buffer.get() != '\n') { - length++; // read until linebreak - // TODO how to read temperature here - } - - this.buffer.reset(); // set to after semicolon - consumer.accept(key, readTemperature(length)); - } - } - - Map accumulate(Map initial) { - - traverse((station, temperature) -> { - initial.compute(station, (k, m) -> { - if (m == null) { - return new Measurement(temperature); + int start; + int keyHash; + int length; + while (this.position < this.size) { + byte b; + start = this.position; // save line start position + keyHash = UNSAFE.getByte(this.segmentAddress + this.position++); // first byte is guaranteed not to be ';' + length = 1; // min key length + while ((b = UNSAFE.getByte(this.segmentAddress + this.position++)) != ';') { // read until semicolon + keyHash = calculateHash(keyHash, b); // calculate key hash ahead, eleminates one more loop later + length++; } - // aggregate - m.min = Math.min(m.min, temperature); - m.max = Math.max(m.max, temperature); - m.sum += temperature; - m.count++; - return m; - }); + + final int temp = readTemperature(); + this.map.putAndCollect(this.segmentAddress, start, length, keyHash, temp); + + this.position++; // skip linebreak + } }); - - return initial; + this.runner.start(); } - // caching Math.pow calculation improves a lot! - // interestingly, instance field access is much faster than static field access - final int[] powerOfTenCache = new int[]{ 1, 10, 100 }; + static int calculateHash(int hash, int b) { + return 31 * hash + b; + } - int readTemperature(int length) { + // 1. Inspired by @yemreinci - Reading temparature value without Double.parse + // 2. Inspired by @obourgain - Fetching first 4 bytes ahead, then masking + int readTemperature() { int temp = 0; - final byte b1 = this.buffer.get(); // get first byte + // read 4 bytes ahead + final int first4 = UNSAFE.getInt(this.segmentAddress + this.position); + this.position += 3; - int digits = length - 4; // digit position - final boolean negative = b1 == '-'; - if (!negative) { - temp += this.powerOfTenCache[digits + 1] * (b1 - 48); // add first digit ahead - } - - byte b; - while ((b = this.buffer.get()) != '.') { // read until dot - temp += this.powerOfTenCache[digits--] * (b - 48); - } - b = this.buffer.get(); // read after dot, only one digit no loop - temp += this.powerOfTenCache[digits] * (b - 48); - this.buffer.get(); // skip line break - - return (negative) ? -temp : temp; - } - - ByteBuffer getKeyRef(int length) { - final ByteBuffer slice = this.buffer.slice().limit(length - 1); - skip(length); - return slice; - } - - void skip(int length) { - final int pos = this.buffer.position(); - this.buffer.position(pos + length); - } - - } - - static class FastDataReader implements Closeable { - - private final FixedRegionDataAccessor[] accessors; - private final ExecutorService mergerThread; - private final ExecutorService accessorPool; - - public FastDataReader(Path path) throws IOException { - var concurrency = Runtime.getRuntime().availableProcessors(); - final long fileSize = Files.size(path); - long regionSize = fileSize / concurrency; - - // handling extreme cases - while (regionSize > Integer.MAX_VALUE) { - concurrency *= 2; - regionSize = fileSize / concurrency; - } - if (regionSize <= 256) { // small file, no need concurrency - concurrency = 1; - regionSize = fileSize; - } - - long startPosition = 0; - this.accessors = new FixedRegionDataAccessor[concurrency]; - for (int i = 0; i < concurrency - 1; i++) { - // map regions - try (final FileChannel channel = (FileChannel) Files.newByteChannel(path, StandardOpenOption.READ)) { - final long maxSize = startPosition + regionSize > fileSize ? fileSize - startPosition : regionSize; - final MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, startPosition, maxSize); - this.accessors[i] = new FixedRegionDataAccessor(startPosition, maxSize, buffer); - // adjust positions back and forth until we find a linebreak! - final int closestPos = findClosestLineEnd((int) maxSize - 1, buffer); - buffer.limit(closestPos + 1); - startPosition += closestPos + 1; + final byte b1 = (byte) first4; // first byte + final byte b2 = (byte) ((first4 >> 8) & 0xFF); // second byte + final byte b3 = (byte) ((first4 >> 16) & 0xFF); // third byte + if (b1 == '-') { + if (b3 == '.') { + temp -= 10 * (b2 - '0') + (byte) ((first4 >> 24) & 0xFF) - '0'; // fourth byte + this.position++; + } + else { + this.position++; // skip dot + temp -= 100 * (b2 - '0') + 10 * (b3 - '0') + UNSAFE.getByte(this.segmentAddress + this.position++) - '0'; // fifth byte } } - // map the last region - try (final FileChannel channel = (FileChannel) Files.newByteChannel(path, StandardOpenOption.READ)) { - final long maxSize = fileSize - startPosition; // last region will take the rest - final MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, startPosition, maxSize); - this.accessors[concurrency - 1] = new FixedRegionDataAccessor(startPosition, maxSize, buffer); + else { + if (b2 == '.') { + temp = 10 * (b1 - '0') + b3 - '0'; + } + else { + temp = 100 * (b1 - '0') + 10 * (b2 - '0') + (byte) ((first4 >> 24) & 0xFF) - '0'; // fourth byte + this.position++; + } } - // create executors - this.mergerThread = Executors.newSingleThreadExecutor(); - this.accessorPool = Executors.newFixedThreadPool(concurrency); - } - void readAndCollect(Map output) { - for (final FixedRegionDataAccessor accessor : this.accessors) { - this.accessorPool.submit(() -> { - final Map partial = accessor.accumulate(new HashMap<>(1 << 10, 1)); // aka 1k - this.mergerThread.submit(() -> mergeMaps(output, partial)); - }); - } - } - - @Override - public void close() { - try { - this.accessorPool.shutdown(); - this.accessorPool.awaitTermination(60, TimeUnit.SECONDS); - this.mergerThread.shutdown(); - this.mergerThread.awaitTermination(60, TimeUnit.SECONDS); - } - catch (Exception e) { - this.accessorPool.shutdownNow(); - this.mergerThread.shutdownNow(); - } + return temp; } /** - * Scans the given buffer to the left + * blocks until the map is fully collected */ - private static int findClosestLineEnd(int regionSize, ByteBuffer buffer) { - int position = regionSize; - int left = regionSize; - while (buffer.get(position) != '\n') { - position = --left; + RecordMap get() throws InterruptedException { + this.runner.join(); + return this.map; + } + } + + private static double round(double value) { + return Math.round(value * 10.0) / 10.0; + } + + /** + * Scans the given buffer to the left + */ + private static long findClosestLineEnd(long start, int size, FileChannel channel) throws IOException { + final long position = start + size; + final long left = Math.max(position - 101, 0); + final ByteBuffer buffer = ByteBuffer.allocate(101); // enough size to find at least one '\n' + if (channel.read(buffer.clear(), left) != -1) { + int bufferPos = buffer.position() - 1; + while (buffer.get(bufferPos) != '\n') { + bufferPos--; + size--; } - return position; } - - private static Map mergeMaps(Map map1, Map map2) { - map2.forEach((s, measurement) -> { - map1.merge(s, measurement, (m1, m2) -> { - m1.min = Math.min(m1.min, m2.min); - m1.max = Math.max(m1.max, m2.max); - m1.sum += m2.sum; - m1.count += m2.count; - return m1; - }); - }); - - return map1; - } - + return size; } public static void main(String[] args) throws IOException, InterruptedException { - final Map output = new HashMap<>(1 << 10, 1); // aka 1k - try (final FastDataReader reader = new FastDataReader(FILE)) { - reader.readAndCollect(output); + + var concurrency = Runtime.getRuntime().availableProcessors(); + final long fileSize = Files.size(FILE); + long regionSize = fileSize / concurrency; + + // handling extreme cases + while (regionSize > Integer.MAX_VALUE) { + concurrency *= 2; + regionSize /= 2; + } + if (fileSize <= 1 << 20) { // small file (1mb), no need concurrency + concurrency = 1; + regionSize = fileSize; } - final TreeMap sorted = new TreeMap<>(); - output.forEach((s, measurement) -> sorted.put(s.toString(), measurement)); + long startPos = 0; + final FileChannel channel = (FileChannel) Files.newByteChannel(FILE, StandardOpenOption.READ); + final RegionActor[] actors = new RegionActor[concurrency]; + for (int i = 0; i < concurrency; i++) { + // calculate boundaries + long maxSize = (startPos + regionSize > fileSize) ? fileSize - startPos : regionSize; + // shift position to back until we find a linebreak + maxSize = findClosestLineEnd(startPos, (int) maxSize, channel); + + final RegionActor region = (actors[i] = new RegionActor(channel, startPos, (int) maxSize)); + region.accumulate(); + + startPos += maxSize; + } + + final RecordMap output = new RecordMap(); // output to merge all regions + for (RegionActor actor : actors) { + final RecordMap partial = actor.get(); // blocks until get the result + output.merge(partial); + } + + // sort and print the result + final TreeMap sorted = new TreeMap<>(); + output.forEach(key -> sorted.put(key.toString(), key.measurements())); System.out.println(sorted); + } }