# Patch summary for src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java
#
# (Free text placed before the "diff --git" header; the hunks below are unchanged.)
#
# What the patch changes:
#   * Imports: drops the java.util.* wildcard and the java.util.concurrent
#     ExecutionException / Executors / Future imports; adds explicit imports of
#     java.util.ArrayList, java.util.List and java.util.TreeMap.
#   * Constants: removes MIN_FILE_SIZE; adds L3_CACHE_SIZE (128 * 1024 * 1024),
#     CORES (Runtime.getRuntime().availableProcessors()) and CHUNK_SIZE, which is
#     computed from L3_CACHE_SIZE, CORES and MeasurementContainer.SIZE / ENTRY_SIZE.
#   * main(): now also declares InterruptedException. The chunk size becomes
#     Math.min(fileSize, CHUNK_SIZE) instead of "fileSize or fileSize / cores".
#     The fixed thread pool and Future handling are replaced by an array of CORES
#     TaskThread objects, each constructed with its own MeasurementContainer.
#     Chunks are assigned with tasks[i % CORES], all threads are started, and then
#     each thread is joined in array order and its measurements() are merged into
#     a TreeMap keyed by station. Writing the result to System.out moves inside
#     the try-with-resources block that holds the FileChannel.
#   * splitByChunks(): the result list is created with an initial capacity of
#     (end - address) / minChunkSize + 1.
#   * calcForChunk(): the static method that returned a map is removed; a method
#     with the same parsing body is added to TaskThread and calls container.put(...)
#     on the thread's own container instead of returning a value.
#   * MeasurementAggregation is renamed to TemperatureAggregation (class,
#     constructor and the merge() parameter type).
#   * A new record Measurement(String station, TemperatureAggregation aggregation)
#     is added.
#   * MeasurementContainer.toStringMap() is replaced by measurements(), which
#     collects Measurement records into an ArrayList created with capacity 1000.
#   * MeasurementContainer.isEqual(...) becomes static.
#   * A new nested class TaskThread (extends Thread) holds a container plus two
#     lists, begins and ends, filled by addChunk(begin, end); run() calls
#     calcForChunk(begins.get(i), ends.get(i)) for every recorded chunk.
#   * The old file had no trailing newline ("\ No newline at end of file" follows
#     the removed closing brace); the new closing brace carries no such marker.
#
# NOTE(review): in this copy the line breaks and indentation of the patch appear
#   to have been collapsed, and generic type arguments look stripped (for example
#   "List>> fResults", "new HashMap()", "new TreeMap()", "List splitByChunks").
#   Presumably an extraction artifact - confirm against the original commit
#   before trying to apply this patch.
# NOTE(review): in the new main(), minChunkSize = Math.min(fileSize, CHUNK_SIZE)
#   is 0 when channel.size() returns 0, and the following
#   "fileSize / minChunkSize / CORES + 1" is then an integer division by zero.
#   TODO confirm whether an empty input file needs to be supported.
# NOTE(review): the new main() declares results as a TreeMap and then writes
#   "new TreeMap<>(results)"; the second copy looks redundant - verify.
#
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java index 789db73..2e7ea4c 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java @@ -25,14 +25,15 @@ import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.*; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; +import java.util.ArrayList; +import java.util.List; +import java.util.TreeMap; public class CalculateAverage_zerninv { private static final String FILE = "./measurements.txt"; - private static final int MIN_FILE_SIZE = 1024 * 1024 * 16; + private static final int L3_CACHE_SIZE = 128 * 1024 * 1024; + private static final int CORES = Runtime.getRuntime().availableProcessors(); + private static final int CHUNK_SIZE = (L3_CACHE_SIZE - MeasurementContainer.SIZE * MeasurementContainer.ENTRY_SIZE * CORES) / CORES - 1024 * CORES; // #.## private static final int THREE_DIGITS_MASK = 0x2e0000; @@ -48,47 +49,48 @@ public class CalculateAverage_zerninv { private static final Unsafe UNSAFE = initUnsafe(); - public static void main(String[] args) throws IOException { - var results = new HashMap(); + public static void main(String[] args) throws IOException, InterruptedException { try (var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { var fileSize = channel.size(); - var memorySegment = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()); - long address = memorySegment.address(); - var cores = Runtime.getRuntime().availableProcessors(); - var minChunkSize = fileSize < MIN_FILE_SIZE ? 
fileSize : fileSize / cores; - var chunks = splitByChunks(address, address + fileSize, minChunkSize); + var minChunkSize = Math.min(fileSize, CHUNK_SIZE); - var executor = Executors.newFixedThreadPool(cores); - List>> fResults = new ArrayList<>(); - for (int i = 1; i < chunks.size(); i++) { - final long prev = chunks.get(i - 1); - final long curr = chunks.get(i); - fResults.add(executor.submit(() -> calcForChunk(prev, curr))); + var tasks = new TaskThread[CORES]; + for (int i = 0; i < tasks.length; i++) { + tasks[i] = new TaskThread(new MeasurementContainer(), (int) (fileSize / minChunkSize / CORES + 1)); } - fResults.forEach(f -> { - try { - f.get().forEach((key, value) -> { - var result = results.get(key); - if (result != null) { - result.merge(value); - } - else { - results.put(key, value); - } - }); - } - catch (InterruptedException | ExecutionException e) { - e.printStackTrace(); - } - }); - executor.shutdown(); - } + var memorySegment = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()); + var address = memorySegment.address(); + var chunks = splitByChunks(address, address + fileSize, minChunkSize); + for (int i = 0; i < chunks.size() - 1; i++) { + var task = tasks[i % CORES]; + task.addChunk(chunks.get(i), chunks.get(i + 1)); + } - var bos = new BufferedOutputStream(System.out); - bos.write(new TreeMap<>(results).toString().getBytes(StandardCharsets.UTF_8)); - bos.write('\n'); - bos.flush(); + for (var task : tasks) { + task.start(); + } + + var results = new TreeMap(); + for (var task : tasks) { + task.join(); + task.measurements() + .forEach(measurement -> { + var aggr = results.get(measurement.station()); + if (aggr == null) { + results.put(measurement.station(), measurement.aggregation()); + } + else { + aggr.merge(measurement.aggregation()); + } + }); + } + + var bos = new BufferedOutputStream(System.out); + bos.write(new TreeMap<>(results).toString().getBytes(StandardCharsets.UTF_8)); + bos.write('\n'); + bos.flush(); + } } private 
static Unsafe initUnsafe() { @@ -103,7 +105,7 @@ public class CalculateAverage_zerninv { } private static List splitByChunks(long address, long end, long minChunkSize) { - List result = new ArrayList<>(); + List result = new ArrayList<>((int) ((end - address) / minChunkSize + 1)); result.add(address); while (address < end) { address += Math.min(end - address, minChunkSize); @@ -114,60 +116,20 @@ public class CalculateAverage_zerninv { return result; } - private static Map calcForChunk(long offset, long end) { - var results = new MeasurementContainer(); - - long cityOffset; - int hashCode, temperature, word; - byte cityNameSize, b; - - while (offset < end) { - cityOffset = offset; - hashCode = 0; - while ((b = UNSAFE.getByte(offset++)) != DELIMITER) { - hashCode = hashCode * 31 + b; - } - cityNameSize = (byte) (offset - cityOffset - 1); - - word = UNSAFE.getInt(offset); - offset += 4; - - if ((word & TWO_NEGATIVE_DIGITS_MASK) == TWO_NEGATIVE_DIGITS_MASK) { - word >>>= 8; - temperature = ZERO * 11 - ((word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK)); - } - else if ((word & THREE_DIGITS_MASK) == THREE_DIGITS_MASK) { - temperature = (word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK) - ZERO * 111; - } - else if ((word & TWO_DIGITS_MASK) == TWO_DIGITS_MASK) { - temperature = (word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK) - ZERO * 11; - offset--; - } - else { - // #.##- - word = (word >>> 8) | (UNSAFE.getByte(offset++) << 24); - temperature = ZERO * 111 - ((word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK)); - } - offset++; - results.put(cityOffset, cityNameSize, hashCode, (short) temperature); - } - return results.toStringMap(); - } - - private static final class MeasurementAggregation { + private static final class TemperatureAggregation { private long sum; private int count; private short min; private short max; - public MeasurementAggregation(long sum, int count, short min, short 
max) { + public TemperatureAggregation(long sum, int count, short min, short max) { this.sum = sum; this.count = count; this.min = min; this.max = max; } - public void merge(MeasurementAggregation o) { + public void merge(TemperatureAggregation o) { if (o == null) { return; } @@ -183,6 +145,9 @@ public class CalculateAverage_zerninv { } } + private record Measurement(String station, TemperatureAggregation aggregation) { + } + private static final class MeasurementContainer { private static final int SIZE = 1024 * 16; @@ -235,26 +200,26 @@ public class CalculateAverage_zerninv { } } - public Map toStringMap() { - var result = new HashMap(); + public List measurements() { + var result = new ArrayList(1000); int count; for (int i = 0; i < SIZE; i++) { long ptr = this.address + i * ENTRY_SIZE; count = UNSAFE.getInt(ptr + COUNT_OFFSET); if (count != 0) { - var measurements = new MeasurementAggregation( + var measurements = new TemperatureAggregation( UNSAFE.getLong(ptr + SUM_OFFSET), count, UNSAFE.getShort(ptr + MIN_OFFSET), UNSAFE.getShort(ptr + MAX_OFFSET)); var key = createString(UNSAFE.getLong(ptr + ADDRESS_OFFSET), UNSAFE.getByte(ptr + SIZE_OFFSET)); - result.put(key, measurements); + result.add(new Measurement(key, measurements)); } } return result; } - private boolean isEqual(long address, long address2, byte size) { + private static boolean isEqual(long address, long address2, byte size) { for (int i = 0; i < size; i++) { if (UNSAFE.getByte(address + i) != UNSAFE.getByte(address2 + i)) { return false; @@ -271,4 +236,69 @@ public class CalculateAverage_zerninv { return new String(arr); } } -} \ No newline at end of file + + private static class TaskThread extends Thread { + private final MeasurementContainer container; + private final List begins; + private final List ends; + + private TaskThread(MeasurementContainer container, int chunks) { + this.container = container; + this.begins = new ArrayList<>(chunks); + this.ends = new ArrayList<>(chunks); + } + + 
public void addChunk(long begin, long end) { + begins.add(begin); + ends.add(end); + } + + @Override + public void run() { + for (int i = 0; i < begins.size(); i++) { + calcForChunk(begins.get(i), ends.get(i)); + } + } + + public List measurements() { + return container.measurements(); + } + + private void calcForChunk(long offset, long end) { + long cityOffset; + int hashCode, temperature, word; + byte cityNameSize, b; + + while (offset < end) { + cityOffset = offset; + hashCode = 0; + while ((b = UNSAFE.getByte(offset++)) != DELIMITER) { + hashCode = hashCode * 31 + b; + } + cityNameSize = (byte) (offset - cityOffset - 1); + + word = UNSAFE.getInt(offset); + offset += 4; + + if ((word & TWO_NEGATIVE_DIGITS_MASK) == TWO_NEGATIVE_DIGITS_MASK) { + word >>>= 8; + temperature = ZERO * 11 - ((word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK)); + } + else if ((word & THREE_DIGITS_MASK) == THREE_DIGITS_MASK) { + temperature = (word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK) - ZERO * 111; + } + else if ((word & TWO_DIGITS_MASK) == TWO_DIGITS_MASK) { + temperature = (word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK) - ZERO * 11; + offset--; + } + else { + // #.##- + word = (word >>> 8) | (UNSAFE.getByte(offset++) << 24); + temperature = ZERO * 111 - ((word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK)); + } + offset++; + container.put(cityOffset, cityNameSize, hashCode, (short) temperature); + } + } + } +}