improved artsiomkorzun solution
This commit is contained in:
		
				
					committed by
					
						 Gunnar Morling
						Gunnar Morling
					
				
			
			
				
	
			
			
			
						parent
						
							a53549ae50
						
					
				
				
					commit
					cec579b506
				
			| @@ -17,4 +17,6 @@ | ||||
|  | ||||
|  | ||||
| JAVA_OPTS="-XX:+UseParallelGC" | ||||
| source "$HOME/.sdkman/bin/sdkman-init.sh" | ||||
| sdk use java 21.0.1-graal 1>&2 | ||||
| time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_artsiomkorzun | ||||
|   | ||||
| @@ -24,70 +24,51 @@ import java.nio.file.Path; | ||||
| import java.nio.file.StandardOpenOption; | ||||
| import java.util.Arrays; | ||||
| import java.util.Comparator; | ||||
| import java.util.concurrent.atomic.AtomicInteger; | ||||
| import java.util.concurrent.atomic.AtomicReference; | ||||
| import java.util.function.Consumer; | ||||
| import java.util.stream.IntStream; | ||||
|  | ||||
| public class CalculateAverage_artsiomkorzun { | ||||
|  | ||||
|     private static final Path FILE = Path.of("./measurements.txt"); | ||||
|     private static final long FILE_SIZE = size(FILE); | ||||
|  | ||||
|     private static final int PARALLELISM = Runtime.getRuntime().availableProcessors(); | ||||
|     private static final int SEGMENT_SIZE = 16 * 1024 * 1024; | ||||
|     private static final int SEGMENT_COUNT = (int) ((FILE_SIZE + SEGMENT_SIZE - 1) / SEGMENT_SIZE); | ||||
|     private static final int SEGMENT_OVERLAP = 1024; | ||||
|  | ||||
|     public static void main(String[] args) throws Exception { | ||||
|         /* | ||||
|          * for (int i = 0; i < 10; i++) { | ||||
|          * long start = System.currentTimeMillis(); | ||||
|          * execute(); | ||||
|          * long end = System.currentTimeMillis(); | ||||
|          * System.err.println("Time: " + (end - start)); | ||||
|          * } | ||||
|          */ | ||||
|         // for (int i = 0; i < 10; i++) { | ||||
|         // long start = System.currentTimeMillis(); | ||||
|         // execute(); | ||||
|         // long end = System.currentTimeMillis(); | ||||
|         // System.err.println("Time: " + (end - start)); | ||||
|         // } | ||||
|  | ||||
|         execute(); | ||||
|     } | ||||
|  | ||||
|     private static void execute() { | ||||
|         Aggregates aggregates = IntStream.range(0, SEGMENT_COUNT) | ||||
|                 .parallel() | ||||
|                 .mapToObj(CalculateAverage_artsiomkorzun::aggregate) | ||||
|                 .reduce(new Aggregates(), CalculateAverage_artsiomkorzun::merge) | ||||
|                 .sort(); | ||||
|     private static void execute() throws Exception { | ||||
|         AtomicInteger counter = new AtomicInteger(); | ||||
|         AtomicReference<Aggregates> result = new AtomicReference<>(); | ||||
|         Aggregator[] aggregators = new Aggregator[PARALLELISM]; | ||||
|  | ||||
|         for (int i = 0; i < aggregators.length; i++) { | ||||
|             aggregators[i] = new Aggregator(counter, result); | ||||
|             aggregators[i].start(); | ||||
|         } | ||||
|  | ||||
|         for (int i = 0; i < aggregators.length; i++) { | ||||
|             aggregators[i].join(); | ||||
|         } | ||||
|  | ||||
|         Aggregates aggregates = result.get(); | ||||
|         aggregates.sort(); | ||||
|  | ||||
|         print(aggregates); | ||||
|     } | ||||
|  | ||||
|     private static Aggregates aggregate(int segment) { | ||||
|         long position = (long) SEGMENT_SIZE * segment; | ||||
|         int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, FILE_SIZE - position); | ||||
|         int limit = Math.min(SEGMENT_SIZE, size - 1); | ||||
|  | ||||
|         MappedByteBuffer buffer = map(position, size); // leaking until gc | ||||
|  | ||||
|         if (position > 0) { | ||||
|             next(buffer); | ||||
|         } | ||||
|  | ||||
|         Aggregates aggregates = new Aggregates(); | ||||
|         Row row = new Row(); | ||||
|  | ||||
|         while (buffer.position() <= limit) { | ||||
|             parse(buffer, row); | ||||
|             aggregates.add(row); | ||||
|         } | ||||
|  | ||||
|         return aggregates; | ||||
|     } | ||||
|  | ||||
|     private static Aggregates merge(Aggregates lefts, Aggregates rights) { | ||||
|         Aggregates to = (lefts.size() < rights.size()) ? rights : lefts; | ||||
|         Aggregates from = (lefts.size() < rights.size()) ? lefts : rights; | ||||
|         from.visit(to::merge); | ||||
|         return to; | ||||
|     } | ||||
|  | ||||
|     private static void print(Aggregates aggregates) { | ||||
|         StringBuilder builder = new StringBuilder(aggregates.size() * 15 + 32); | ||||
|         builder.append("{"); | ||||
| @@ -111,62 +92,11 @@ public class CalculateAverage_artsiomkorzun { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static MappedByteBuffer map(long position, int size) { | ||||
|         try (FileChannel channel = FileChannel.open(FILE, StandardOpenOption.READ)) { | ||||
|             return channel.map(FileChannel.MapMode.READ_ONLY, position, size); // leaking until gc | ||||
|         } | ||||
|         catch (Throwable e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static void next(ByteBuffer buffer) { | ||||
|         while (buffer.get() != '\n') { | ||||
|             // continue | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static void parse(ByteBuffer buffer, Row row) { | ||||
|         int index = 0; | ||||
|         byte b; | ||||
|  | ||||
|         while ((b = buffer.get()) != ';') { | ||||
|             row.station[index++] = b; | ||||
|         } | ||||
|  | ||||
|         row.length = index; | ||||
|  | ||||
|         double value = 0; | ||||
|         double multiplier = 1; | ||||
|  | ||||
|         b = buffer.get(); | ||||
|         if (b == '-') { | ||||
|             multiplier = -1; | ||||
|         } | ||||
|         else { | ||||
|             assert b >= '0' && b <= '9'; | ||||
|             value = b - '0'; | ||||
|         } | ||||
|  | ||||
|         while ((b = buffer.get()) != '.') { | ||||
|             assert b >= '0' && b <= '9'; | ||||
|             value = 10 * value + (b - '0'); | ||||
|         } | ||||
|  | ||||
|         b = buffer.get(); | ||||
|         assert b >= '0' && b <= '9'; | ||||
|         value = 10 * value + (b - '0'); | ||||
|  | ||||
|         b = buffer.get(); | ||||
|         assert b == '\n'; | ||||
|  | ||||
|         row.temperature = value * multiplier; | ||||
|     } | ||||
|  | ||||
|     private static class Row { | ||||
|         final byte[] station = new byte[256]; | ||||
|         int length; | ||||
|         double temperature; | ||||
|         int hash; | ||||
|         int temperature; | ||||
|  | ||||
|         @Override | ||||
|         public String toString() { | ||||
| @@ -176,23 +106,25 @@ public class CalculateAverage_artsiomkorzun { | ||||
|  | ||||
|     private static class Aggregate implements Comparable<Aggregate> { | ||||
|         final byte[] station; | ||||
|         double min; | ||||
|         double max; | ||||
|         double sum; | ||||
|         double count; | ||||
|         final int hash; | ||||
|         int min; | ||||
|         int max; | ||||
|         long sum; | ||||
|         int count; | ||||
|  | ||||
|         public Aggregate(byte[] station, int length, double temperature) { | ||||
|             this.station = Arrays.copyOf(station, length); | ||||
|             this.min = temperature; | ||||
|             this.max = temperature; | ||||
|             this.sum = temperature; | ||||
|         public Aggregate(Row row) { | ||||
|             this.station = Arrays.copyOf(row.station, row.length); | ||||
|             this.hash = row.hash; | ||||
|             this.min = row.temperature; | ||||
|             this.max = row.temperature; | ||||
|             this.sum = row.temperature; | ||||
|             this.count = 1; | ||||
|         } | ||||
|  | ||||
|         public void add(double temperature) { | ||||
|             min = Math.min(min, temperature); | ||||
|             max = Math.max(max, temperature); | ||||
|             sum += temperature; | ||||
|         public void add(Row row) { | ||||
|             min = Math.min(min, row.temperature); | ||||
|             max = Math.max(max, row.temperature); | ||||
|             sum += row.temperature; | ||||
|             count++; | ||||
|         } | ||||
|  | ||||
| @@ -223,7 +155,7 @@ public class CalculateAverage_artsiomkorzun { | ||||
|  | ||||
|         @Override | ||||
|         public String toString() { | ||||
|             return new String(station) + "=" + round(min) + "/" + round(sum / count) + "/" + round(max); | ||||
|             return new String(station) + "=" + round(min) + "/" + round(1.0 * sum / count) + "/" + round(max); | ||||
|         } | ||||
|  | ||||
|         private static double round(double v) { | ||||
| @@ -255,26 +187,21 @@ public class CalculateAverage_artsiomkorzun { | ||||
|         } | ||||
|  | ||||
|         public void add(Row row) { | ||||
|             byte[] station = row.station; | ||||
|             int length = row.length; | ||||
|             double temperature = row.temperature; | ||||
|  | ||||
|             int hash = hash(station, length); | ||||
|             int index = hash & (aggregates.length - 1); | ||||
|             int index = row.hash & (aggregates.length - 1); | ||||
|  | ||||
|             while (true) { | ||||
|                 Aggregate aggregate = aggregates[index]; | ||||
|  | ||||
|                 if (aggregate == null) { | ||||
|                     aggregates[index] = new Aggregate(station, length, temperature); | ||||
|                     aggregates[index] = new Aggregate(row); | ||||
|                     if (++size >= limit) { | ||||
|                         grow(); | ||||
|                     } | ||||
|                     break; | ||||
|                 } | ||||
|  | ||||
|                 if (equal(station, length, aggregate.station, aggregate.station.length)) { | ||||
|                     aggregate.add(temperature); | ||||
|                 if (row.hash == aggregate.hash && Arrays.equals(row.station, 0, row.length, aggregate.station, 0, aggregate.station.length)) { | ||||
|                     aggregate.add(row); | ||||
|                     break; | ||||
|                 } | ||||
|  | ||||
| @@ -283,10 +210,7 @@ public class CalculateAverage_artsiomkorzun { | ||||
|         } | ||||
|  | ||||
|         public void merge(Aggregate right) { | ||||
|             byte[] station = right.station; | ||||
|  | ||||
|             int hash = hash(station, station.length); | ||||
|             int index = hash & (aggregates.length - 1); | ||||
|             int index = right.hash & (aggregates.length - 1); | ||||
|  | ||||
|             while (true) { | ||||
|                 Aggregate aggregate = aggregates[index]; | ||||
| @@ -299,7 +223,7 @@ public class CalculateAverage_artsiomkorzun { | ||||
|                     break; | ||||
|                 } | ||||
|  | ||||
|                 if (equal(station, station.length, aggregate.station, aggregate.station.length)) { | ||||
|                 if (right.hash == aggregate.hash && Arrays.equals(right.station, aggregate.station)) { | ||||
|                     aggregate.merge(right); | ||||
|                     break; | ||||
|                 } | ||||
| @@ -309,7 +233,7 @@ public class CalculateAverage_artsiomkorzun { | ||||
|         } | ||||
|  | ||||
|         public Aggregates sort() { | ||||
|             Arrays.parallelSort(aggregates, Comparator.nullsLast(Aggregate::compareTo)); | ||||
|             Arrays.sort(aggregates, Comparator.nullsLast(Aggregate::compareTo)); | ||||
|             return this; | ||||
|         } | ||||
|  | ||||
| @@ -320,8 +244,7 @@ public class CalculateAverage_artsiomkorzun { | ||||
|  | ||||
|             for (Aggregate aggregate : oldAggregates) { | ||||
|                 if (aggregate != null) { | ||||
|                     int hash = hash(aggregate.station, aggregate.station.length); | ||||
|                     int index = hash & (aggregates.length - 1); | ||||
|                     int index = aggregate.hash & (aggregates.length - 1); | ||||
|  | ||||
|                     while (aggregates[index] != null) { | ||||
|                         index = (index + 1) & (aggregates.length - 1); | ||||
| @@ -331,29 +254,105 @@ public class CalculateAverage_artsiomkorzun { | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|         private static int hash(byte[] array, int length) { | ||||
|             int hash = 0; | ||||
|     private static class Aggregator extends Thread { | ||||
|  | ||||
|             for (int i = 0; i < length; i++) { | ||||
|                 hash = 71 * hash + array[i]; | ||||
|             } | ||||
|         private final AtomicInteger counter; | ||||
|         private final AtomicReference<Aggregates> result; | ||||
|  | ||||
|             return hash; | ||||
|         public Aggregator(AtomicInteger counter, AtomicReference<Aggregates> result) { | ||||
|             super("aggregator"); | ||||
|             this.counter = counter; | ||||
|             this.result = result; | ||||
|         } | ||||
|  | ||||
|         private static boolean equal(byte[] left, int leftLength, byte[] right, int rightLength) { | ||||
|             if (leftLength != rightLength) { | ||||
|                 return false; | ||||
|             } | ||||
|         @Override | ||||
|         public void run() { | ||||
|             Aggregates aggregates = new Aggregates(); | ||||
|             Row row = new Row(); | ||||
|  | ||||
|             for (int i = 0; i < leftLength; i++) { | ||||
|                 if (left[i] != right[i]) { | ||||
|                     return false; | ||||
|             try (FileChannel channel = FileChannel.open(FILE, StandardOpenOption.READ)) { | ||||
|                 for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) { | ||||
|                     aggregate(channel, segment, aggregates, row); | ||||
|                 } | ||||
|             } | ||||
|             catch (Throwable e) { | ||||
|                 throw new RuntimeException(e); | ||||
|             } | ||||
|  | ||||
|             return true; | ||||
|             while (!result.compareAndSet(null, aggregates)) { | ||||
|                 Aggregates rights = result.getAndSet(null); | ||||
|  | ||||
|                 if (rights != null) { | ||||
|                     aggregates = merge(aggregates, rights); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Maps one segment of the input file and folds every row that starts inside it into {@code aggregates}. | ||||
|          * | ||||
|          * @param channel    open channel on the input file, used to map the segment | ||||
|          * @param segment    zero-based segment index; the segment starts at byte {@code SEGMENT_SIZE * segment} | ||||
|          * @param aggregates table that receives each parsed row via {@code add} | ||||
|          * @param row        scratch row that is overwritten for every parsed line | ||||
|          * @throws Exception declared by the signature; {@code channel.map} is the visible call that can fail | ||||
|          */ | ||||
|         private static void aggregate(FileChannel channel, int segment, Aggregates aggregates, Row row) throws Exception { | ||||
|             long position = (long) SEGMENT_SIZE * segment; | ||||
|             // Map SEGMENT_OVERLAP extra bytes past the segment (capped at the end of the file) so that a row | ||||
|             // starting near the end of the segment can be read in full. | ||||
|             // NOTE(review): presumably no row is longer than SEGMENT_OVERLAP bytes - confirm. | ||||
|             int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, FILE_SIZE - position); | ||||
|             // Last offset at which a row may start and still be handled by this segment. | ||||
|             int limit = Math.min(SEGMENT_SIZE, size - 1); | ||||
|  | ||||
|             MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, position, size); | ||||
|  | ||||
|             // Every segment but the first skips through its first '\n'; the preceding segment parses any row | ||||
|             // that starts at an offset up to and including SEGMENT_SIZE, which covers the skipped bytes. | ||||
|             if (position > 0) { | ||||
|                 next(buffer); | ||||
|             } | ||||
|  | ||||
|             // parse() returns the offset of the next row, so the loop header has no update clause. | ||||
|             for (int offset = buffer.position(); offset <= limit;) { | ||||
|                 offset = parse(buffer, row, offset); | ||||
|                 aggregates.add(row); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Combines two aggregate tables into one and returns the table that received the entries. | ||||
|          * | ||||
|          * @param lefts  first table | ||||
|          * @param rights second table | ||||
|          * @return whichever of the two tables has the smaller {@code size()} ({@code lefts} on a tie), after every | ||||
|          *         entry of the other table has been merged into it | ||||
|          */ | ||||
|         private static Aggregates merge(Aggregates lefts, Aggregates rights) { | ||||
|             // After this swap lefts.size() <= rights.size(). | ||||
|             if (rights.size() < lefts.size()) { | ||||
|                 Aggregates temp = lefts; | ||||
|                 lefts = rights; | ||||
|                 rights = temp; | ||||
|             } | ||||
|  | ||||
|             // NOTE(review): this visits the larger table and merges it into the smaller one; visiting the | ||||
|             // smaller table and merging into the larger looks cheaper - confirm which direction is intended. | ||||
|             rights.visit(lefts::merge); | ||||
|             return lefts; | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Advances the buffer's position to just past the next {@code '\n'} byte, using relative reads. | ||||
|          * | ||||
|          * @param buffer buffer to advance; the loop has no bounds check of its own, so a buffer without a | ||||
|          *               {@code '\n'} before its limit makes {@code buffer.get()} fail | ||||
|          */ | ||||
|         private static void next(ByteBuffer buffer) { | ||||
|             while (buffer.get() != '\n') { | ||||
|                 // continue | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Parses one {@code station;temperature} row starting at {@code offset} and stores it in {@code row}. | ||||
|          * All reads are absolute ({@code buffer.get(int)}), so the buffer's own position is not moved. | ||||
|          * | ||||
|          * @param buffer buffer holding the row | ||||
|          * @param row    output: {@code station}/{@code length}, {@code hash} and {@code temperature} are overwritten | ||||
|          * @param offset index of the first byte of the row | ||||
|          * @return the offset three bytes past the decimal point, i.e. the start of the next row when the value has | ||||
|          *         one fractional digit followed by a single terminator byte | ||||
|          */ | ||||
|         private static int parse(ByteBuffer buffer, Row row, int offset) { | ||||
|             byte[] station = row.station; | ||||
|             int length = 0; | ||||
|             int hash = 0; | ||||
|  | ||||
|             // Copy the station name up to ';' while computing its hash in the same pass. | ||||
|             for (byte b; (b = buffer.get(offset++)) != ';';) { | ||||
|                 station[length++] = b; | ||||
|                 hash = 71 * hash + b; | ||||
|             } | ||||
|  | ||||
|             row.length = length; | ||||
|             row.hash = hash; | ||||
|  | ||||
|             int sign = 1; | ||||
|  | ||||
|             if (buffer.get(offset) == '-') { | ||||
|                 sign = -1; | ||||
|                 offset++; | ||||
|             } | ||||
|  | ||||
|             // First integer digit, then an optional second one if the next byte is not the decimal point. | ||||
|             // NOTE(review): assumes at most two integer digits and ASCII digits only - nothing is validated here. | ||||
|             int value = buffer.get(offset++) - '0'; | ||||
|  | ||||
|             if (buffer.get(offset) != '.') { | ||||
|                 value = 10 * value + buffer.get(offset++) - '0'; | ||||
|             } | ||||
|  | ||||
|             // offset now indexes the '.', so offset + 1 is the fractional digit. The temperature is kept as an | ||||
|             // int scaled by ten (e.g. "-12.3" becomes -123). | ||||
|             value = 10 * value + buffer.get(offset + 1) - '0'; | ||||
|             row.temperature = value * sign; | ||||
|             // Skip '.', the fractional digit and one more byte (presumably the '\n' terminator). | ||||
|             return offset + 3; | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user