diff --git a/calculate_average_artsiomkorzun.sh b/calculate_average_artsiomkorzun.sh index 7671ba3..f965cda 100755 --- a/calculate_average_artsiomkorzun.sh +++ b/calculate_average_artsiomkorzun.sh @@ -17,4 +17,6 @@ JAVA_OPTS="-XX:+UseParallelGC" +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.1-graal 1>&2 time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_artsiomkorzun diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java index 5efd4a7..516a6ab 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java @@ -24,70 +24,51 @@ import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.Arrays; import java.util.Comparator; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -import java.util.stream.IntStream; public class CalculateAverage_artsiomkorzun { private static final Path FILE = Path.of("./measurements.txt"); private static final long FILE_SIZE = size(FILE); + private static final int PARALLELISM = Runtime.getRuntime().availableProcessors(); private static final int SEGMENT_SIZE = 16 * 1024 * 1024; private static final int SEGMENT_COUNT = (int) ((FILE_SIZE + SEGMENT_SIZE - 1) / SEGMENT_SIZE); private static final int SEGMENT_OVERLAP = 1024; public static void main(String[] args) throws Exception { - /* - * for (int i = 0; i < 10; i++) { - * long start = System.currentTimeMillis(); - * execute(); - * long end = System.currentTimeMillis(); - * System.err.println("Time: " + (end - start)); - * } - */ + // for (int i = 0; i < 10; i++) { + // long start = System.currentTimeMillis(); + // execute(); + // long end = System.currentTimeMillis(); + // System.err.println("Time: " + (end - start)); + // } execute(); } - private static void execute() { - Aggregates aggregates = IntStream.range(0, SEGMENT_COUNT) - .parallel() - .mapToObj(CalculateAverage_artsiomkorzun::aggregate) - .reduce(new Aggregates(), CalculateAverage_artsiomkorzun::merge) - .sort(); + private static void execute() throws Exception { + AtomicInteger counter = new AtomicInteger(); + AtomicReference result = new AtomicReference<>(); + Aggregator[] aggregators = new Aggregator[PARALLELISM]; + + for (int i = 0; i < aggregators.length; i++) { + aggregators[i] = new Aggregator(counter, result); + aggregators[i].start(); + } + + for (int i = 0; i < aggregators.length; i++) { + aggregators[i].join(); + } + + Aggregates aggregates = result.get(); + aggregates.sort(); print(aggregates); } - private static Aggregates aggregate(int segment) { - long position = (long) SEGMENT_SIZE * segment; - int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, FILE_SIZE - position); - int limit = Math.min(SEGMENT_SIZE, size - 1); - - MappedByteBuffer buffer = map(position, size); // leaking until gc - - if (position > 0) { - next(buffer); - } - - Aggregates aggregates = new Aggregates(); - Row row = new Row(); - - while (buffer.position() <= limit) { - parse(buffer, row); - aggregates.add(row); - } - - return aggregates; - } - - private static Aggregates merge(Aggregates lefts, Aggregates rights) { - Aggregates to = (lefts.size() < rights.size()) ? rights : lefts; - Aggregates from = (lefts.size() < rights.size()) ? lefts : rights; - from.visit(to::merge); - return to; - } - private static void print(Aggregates aggregates) { StringBuilder builder = new StringBuilder(aggregates.size() * 15 + 32); builder.append("{"); @@ -111,62 +92,11 @@ public class CalculateAverage_artsiomkorzun { } } - private static MappedByteBuffer map(long position, int size) { - try (FileChannel channel = FileChannel.open(FILE, StandardOpenOption.READ)) { - return channel.map(FileChannel.MapMode.READ_ONLY, position, size); // leaking until gc - } - catch (Throwable e) { - throw new RuntimeException(e); - } - } - - private static void next(ByteBuffer buffer) { - while (buffer.get() != '\n') { - // continue - } - } - - private static void parse(ByteBuffer buffer, Row row) { - int index = 0; - byte b; - - while ((b = buffer.get()) != ';') { - row.station[index++] = b; - } - - row.length = index; - - double value = 0; - double multiplier = 1; - - b = buffer.get(); - if (b == '-') { - multiplier = -1; - } - else { - assert b >= '0' && b <= '9'; - value = b - '0'; - } - - while ((b = buffer.get()) != '.') { - assert b >= '0' && b <= '9'; - value = 10 * value + (b - '0'); - } - - b = buffer.get(); - assert b >= '0' && b <= '9'; - value = 10 * value + (b - '0'); - - b = buffer.get(); - assert b == '\n'; - - row.temperature = value * multiplier; - } - private static class Row { final byte[] station = new byte[256]; int length; - double temperature; + int hash; + int temperature; @Override public String toString() { @@ -176,23 +106,25 @@ public class CalculateAverage_artsiomkorzun { private static class Aggregate implements Comparable { final byte[] station; - double min; - double max; - double sum; - double count; + final int hash; + int min; + int max; + long sum; + int count; - public Aggregate(byte[] station, int length, double temperature) { - this.station = Arrays.copyOf(station, length); - this.min = temperature; - this.max = temperature; - this.sum = temperature; + public Aggregate(Row row) { + this.station = Arrays.copyOf(row.station, row.length); + this.hash = row.hash; + this.min = row.temperature; + this.max = row.temperature; + this.sum = row.temperature; this.count = 1; } - public void add(double temperature) { - min = Math.min(min, temperature); - max = Math.max(max, temperature); - sum += temperature; + public void add(Row row) { + min = Math.min(min, row.temperature); + max = Math.max(max, row.temperature); + sum += row.temperature; count++; } @@ -223,7 +155,7 @@ public class CalculateAverage_artsiomkorzun { @Override public String toString() { - return new String(station) + "=" + round(min) + "/" + round(sum / count) + "/" + round(max); + return new String(station) + "=" + round(min) + "/" + round(1.0 * sum / count) + "/" + round(max); } private static double round(double v) { @@ -255,26 +187,21 @@ public class CalculateAverage_artsiomkorzun { } public void add(Row row) { - byte[] station = row.station; - int length = row.length; - double temperature = row.temperature; - - int hash = hash(station, length); - int index = hash & (aggregates.length - 1); + int index = row.hash & (aggregates.length - 1); while (true) { Aggregate aggregate = aggregates[index]; if (aggregate == null) { - aggregates[index] = new Aggregate(station, length, temperature); + aggregates[index] = new Aggregate(row); if (++size >= limit) { grow(); } break; } - if (equal(station, length, aggregate.station, aggregate.station.length)) { - aggregate.add(temperature); + if (row.hash == aggregate.hash && Arrays.equals(row.station, 0, row.length, aggregate.station, 0, aggregate.station.length)) { + aggregate.add(row); break; } @@ -283,10 +210,7 @@ public class CalculateAverage_artsiomkorzun { } public void merge(Aggregate right) { - byte[] station = right.station; - - int hash = hash(station, station.length); - int index = hash & (aggregates.length - 1); + int index = right.hash & (aggregates.length - 1); while (true) { Aggregate aggregate = aggregates[index]; @@ -299,7 +223,7 @@ public class CalculateAverage_artsiomkorzun { break; } - if (equal(station, station.length, aggregate.station, aggregate.station.length)) { + if (right.hash == aggregate.hash && Arrays.equals(right.station, aggregate.station)) { aggregate.merge(right); break; } @@ -309,7 +233,7 @@ public class CalculateAverage_artsiomkorzun { } public Aggregates sort() { - Arrays.parallelSort(aggregates, Comparator.nullsLast(Aggregate::compareTo)); + Arrays.sort(aggregates, Comparator.nullsLast(Aggregate::compareTo)); return this; } @@ -320,8 +244,7 @@ public class CalculateAverage_artsiomkorzun { for (Aggregate aggregate : oldAggregates) { if (aggregate != null) { - int hash = hash(aggregate.station, aggregate.station.length); - int index = hash & (aggregates.length - 1); + int index = aggregate.hash & (aggregates.length - 1); while (aggregates[index] != null) { index = (index + 1) & (aggregates.length - 1); @@ -331,29 +254,105 @@ public class CalculateAverage_artsiomkorzun { } } } + } - private static int hash(byte[] array, int length) { - int hash = 0; + private static class Aggregator extends Thread { - for (int i = 0; i < length; i++) { - hash = 71 * hash + array[i]; - } + private final AtomicInteger counter; + private final AtomicReference result; - return hash; + public Aggregator(AtomicInteger counter, AtomicReference result) { + super("aggregator"); + this.counter = counter; + this.result = result; } - private static boolean equal(byte[] left, int leftLength, byte[] right, int rightLength) { - if (leftLength != rightLength) { - return false; - } + @Override + public void run() { + Aggregates aggregates = new Aggregates(); + Row row = new Row(); - for (int i = 0; i < leftLength; i++) { - if (left[i] != right[i]) { - return false; + try (FileChannel channel = FileChannel.open(FILE, StandardOpenOption.READ)) { + for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) { + aggregate(channel, segment, aggregates, row); } } + catch (Throwable e) { + throw new RuntimeException(e); + } - return true; + while (!result.compareAndSet(null, aggregates)) { + Aggregates rights = result.getAndSet(null); + + if (rights != null) { + aggregates = merge(aggregates, rights); + } + } + } + + private static void aggregate(FileChannel channel, int segment, Aggregates aggregates, Row row) throws Exception { + long position = (long) SEGMENT_SIZE * segment; + int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, FILE_SIZE - position); + int limit = Math.min(SEGMENT_SIZE, size - 1); + + MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, position, size); + + if (position > 0) { + next(buffer); + } + + for (int offset = buffer.position(); offset <= limit;) { + offset = parse(buffer, row, offset); + aggregates.add(row); + } + } + + private static Aggregates merge(Aggregates lefts, Aggregates rights) { + if (rights.size() < lefts.size()) { + Aggregates temp = lefts; + lefts = rights; + rights = temp; + } + + rights.visit(lefts::merge); + return lefts; + } + + private static void next(ByteBuffer buffer) { + while (buffer.get() != '\n') { + // continue + } + } + + private static int parse(ByteBuffer buffer, Row row, int offset) { + byte[] station = row.station; + int length = 0; + int hash = 0; + + for (byte b; (b = buffer.get(offset++)) != ';';) { + station[length++] = b; + hash = 71 * hash + b; + } + + row.length = length; + row.hash = hash; + + int sign = 1; + + if (buffer.get(offset) == '-') { + sign = -1; + offset++; + } + + int value = buffer.get(offset++) - '0'; + + if (buffer.get(offset) != '.') { + value = 10 * value + buffer.get(offset++) - '0'; + } + + value = 10 * value + buffer.get(offset + 1) - '0'; + row.temperature = value * sign; + return offset + 3; } } }