From 8ba67cbc6d4d83432bac28e453efc7bf3a963c10 Mon Sep 17 00:00:00 2001 From: kumarsaurav123 Date: Sun, 21 Jan 2024 17:20:36 +0530 Subject: [PATCH] Use Array to store results instead of grouping by and custom class (#522) --- .../CalculateAverage_kumarsaurav123.java | 279 +++++++++--------- 1 file changed, 145 insertions(+), 134 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_kumarsaurav123.java b/src/main/java/dev/morling/onebrc/CalculateAverage_kumarsaurav123.java index f991f9f..87458d1 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_kumarsaurav123.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_kumarsaurav123.java @@ -23,132 +23,108 @@ import java.lang.foreign.ValueLayout; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.util.*; -import java.util.concurrent.ConcurrentSkipListMap; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; import java.util.stream.Collector; +import java.util.stream.Collectors; import static java.util.stream.Collectors.groupingBy; public class CalculateAverage_kumarsaurav123 { private static final String FILE = "./measurements.txt"; + private static AtomicInteger indexCount = new AtomicInteger(0); + private static final ReentrantLock lock = new ReentrantLock(); + private static final int MAX_UNIQUE_KEYS = 11000; + private static Map indexMap; - private static record Measurement(String station, double value) { - private Measurement(String[] parts) { - this(parts[0], Double.parseDouble(parts[1])); + private static record Store(double[] min, double[] max, double[] sum, + int[] count) { + + + private double round(double value) { + return Math.round(value * 10.0) / 10.0; + } + + @Override + public String toString() { + return new TreeMap<>(indexMap.entrySet() + .stream() + .map(e -> Map.entry(e.getKey().toString(), + round(min[e.getValue()]) + "/" + round((Math.round(sum[e.getValue()] * 10.0) / 10.0) / count[e.getValue()]) + "/" + round(max[e.getValue()]) + )) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))).toString(); } } private static record Pair(long start, int size) { } - private static record ResultRow(String station, double min, double mean, double max, double sum, double count) { - public String toString() { - return round(min) + "/" + round(mean) + "/" + round(max); - } - - private double round(double value) { - return Math.round(value * 10.0) / 10.0; - } - } - - ; - - private static class MeasurementAggregator { - private double min = Double.POSITIVE_INFINITY; - private double max = Double.NEGATIVE_INFINITY; - private double sum; - private long count; - - private String station; - } - - public static void main(String[] args) throws IOException { + public static void main(String[] args) throws IOException, ExecutionException, InterruptedException { long start = System.currentTimeMillis(); System.out.println(run(FILE)); - // System.out.println(System.currentTimeMillis() - start); } - public static String run(String filePath) throws IOException { - Collector collector2 = Collector.of( - MeasurementAggregator::new, - (a, m) -> { - a.min = Math.min(a.min, m.min); - a.max = Math.max(a.max, m.max); - a.sum += m.sum; - a.count += m.count; - }, - (agg1, agg2) -> { - var res = new MeasurementAggregator(); - res.min = Math.min(agg1.min, agg2.min); - res.max = Math.max(agg1.max, agg2.max); - res.sum = agg1.sum + agg2.sum; - res.count = agg1.count + agg2.count; - - return res; - }, - agg -> { - return new ResultRow(agg.station, agg.min, (Math.round(agg.sum * 10.0) / 10.0) / agg.count, agg.max, agg.sum, agg.count); - }); - Collector collector = Collector.of( - MeasurementAggregator::new, - (a, m) -> { - a.min = Math.min(a.min, m.value); - a.max = Math.max(a.max, m.value); - a.sum += m.value; - a.station = m.station; - a.count++; - }, - (agg1, agg2) -> { - var res = new MeasurementAggregator(); - res.min = Math.min(agg1.min, agg2.min); - res.max = Math.max(agg1.max, agg2.max); - res.sum = agg1.sum + agg2.sum; - res.count = agg1.count + agg2.count; - - return res; - }, - agg -> { - return new ResultRow(agg.station, agg.min, agg.sum / agg.count, agg.max, agg.sum, agg.count); - }); + public static String run(String filePath) throws IOException, InterruptedException, ExecutionException { + indexCount = new AtomicInteger(0); + indexMap = new HashMap<>(MAX_UNIQUE_KEYS); ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2); - List measurements = Collections.synchronizedList(new ArrayList()); - int chunkSize = 1_0000_00; + CompletionService completionService = new ExecutorCompletionService<>(executorService); Map> leftOutsMap = new ConcurrentSkipListMap<>(); RandomAccessFile file = new RandomAccessFile(filePath, "r"); long filelength = file.length(); AtomicInteger kk = new AtomicInteger(); - MemorySegment memorySegment = file.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, filelength, Arena.global()); + MemorySegment memorySegment = file.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, filelength, Arena.ofShared()); int nChunks = 1000; - int pChunkSize = Math.min(Integer.MAX_VALUE, (int) (memorySegment.byteSize() / (1000 * 20))); + int pChunkSize = Math.min(Integer.MAX_VALUE, (int) (memorySegment.byteSize() / (1000))); if (pChunkSize < 100) { pChunkSize = (int) memorySegment.byteSize(); nChunks = 1; } ArrayList chunks = createStartAndEnd(pChunkSize, nChunks, memorySegment); chunks.stream() + .parallel() .map(p -> { - return createRunnable(memorySegment, p, collector, measurements, kk.getAndIncrement()); + return createRunnable(memorySegment, p); }) - .forEach(executorService::submit); + .forEach(completionService::submit); executorService.shutdown(); - try { - executorService.awaitTermination(10, TimeUnit.MINUTES); - } - catch (InterruptedException e) { - throw new RuntimeException(e); + int i = 0; + double[] min = new double[MAX_UNIQUE_KEYS]; + double[] max = new double[MAX_UNIQUE_KEYS]; + double[] sum = new double[MAX_UNIQUE_KEYS]; + int[] count = new int[MAX_UNIQUE_KEYS]; + initArray(i, count, min, max, sum); + i = 0; + final Store cureentStore = new Store(min, max, sum, count); + while (i < chunks.size()) { + Store newStore = completionService.take().get(); + Map reverseMap = indexMap.entrySet() + .stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + reverseMap.forEach((key, value) -> { + cureentStore.sum[key] += newStore.sum[key]; + cureentStore.count[key] += newStore.count[key]; + cureentStore.min[key] = Math.min(cureentStore.min[key], + newStore.min[key]); + cureentStore.max[key] = Math.max(cureentStore.max[key], + newStore.max[key]); + }); + i++; } - Map measurements2 = new TreeMap<>(measurements - .stream() - .parallel() - .collect(groupingBy(ResultRow::station, collector2))); - return measurements2.toString(); + return cureentStore.toString(); + } + + private static void initArray(int i, int[] count, double[] min, double[] max, double[] sum) { + for (; i < count.length; i++) { + min[i] = Double.POSITIVE_INFINITY; + max[i] = Double.NEGATIVE_INFINITY; + sum[i] = 0.0d; + count[i] = 0; + } } private static ArrayList createStartAndEnd(int chunksize, int nChunks, MemorySegment memorySegment) { @@ -174,41 +150,30 @@ public class CalculateAverage_kumarsaurav123 { return startSizePairs; } - public static Runnable createRunnable(MemorySegment memorySegment, Pair p, Collector collector, - List measurements, int kk) { - return new Runnable() { + public static Callable createRunnable(MemorySegment memorySegment, Pair p) { + return new Callable() { @Override - public void run() { + public Store call() { try { - long start = System.currentTimeMillis(); + double[] min = new double[MAX_UNIQUE_KEYS]; + double[] max = new double[MAX_UNIQUE_KEYS]; + double[] sum = new double[MAX_UNIQUE_KEYS]; + int[] count = new int[MAX_UNIQUE_KEYS]; + for (int i = 0; i < count.length; i++) { + min[i] = Double.POSITIVE_INFINITY; + max[i] = Double.NEGATIVE_INFINITY; + sum[i] = 0.0d; + count[i] = 0; + } - byte[] allBytes2 = new byte[p.size]; - MemorySegment lMemory = memorySegment.asSlice(p.start, p.size); - lMemory.asByteBuffer().get(allBytes2); - HashMap map = new HashMap<>(); - // Runtime runtime = Runtime.getRuntime(); - // long memoryMax = runtime.maxMemory(); - // long memoryUsed = runtime.totalMemory() - runtime.freeMemory(); - // double memoryUsedPercent = (memoryUsed * 100.0) / memoryMax; - // System.out.println("memoryUsedPercent: " + memoryUsedPercent); - map.put((byte) 48, 0); - map.put((byte) 49, 1); - map.put((byte) 50, 2); - map.put((byte) 51, 3); - map.put((byte) 52, 4); - map.put((byte) 53, 5); - map.put((byte) 54, 6); - map.put((byte) 55, 7); - map.put((byte) 56, 8); - map.put((byte) 57, 9); + byte[] allBytes2 = memorySegment.asSlice(p.start, p.size).toArray(ValueLayout.JAVA_BYTE); byte[] eol = "\n".getBytes(StandardCharsets.UTF_8); byte[] sep = ";".getBytes(StandardCharsets.UTF_8); - List mst = new ArrayList<>(); int st = 0; - for (int i = 0; i < allBytes2.length; i++) { if (allBytes2[i] == eol[0]) { + ; byte[] s2 = new byte[i - st]; System.arraycopy(allBytes2, st, s2, 0, s2.length); for (int j = 0; j < s2.length; j++) { @@ -217,37 +182,83 @@ public class CalculateAverage_kumarsaurav123 { byte[] value = new byte[s2.length - j - 1]; System.arraycopy(s2, 0, city, 0, city.length); System.arraycopy(s2, city.length + 1, value, 0, value.length); - double d = 0.0; - int s = -1; - for (int k = value.length - 1; k >= 0; k--) { - if (value[k] == 45) { - d = d * -1; - } - else if (value[k] == 46) { - } - else { - d = d + map.get(value[k]).intValue() * Math.pow(10, s); - s++; - } - } - mst.add(new Measurement(new String(city), d)); + double d = getaDouble(value); + StringHolder citys = new StringHolder(city); + Integer index = indexMap.get(citys); + if (Objects.isNull(index)) { + lock.lock(); + if (Objects.isNull(indexMap.get(citys))) { + index = indexCount.getAndIncrement(); + indexMap.putIfAbsent(citys, index); + } + index = indexMap.get(citys); + lock.unlock(); + } + + count[index] = count[index] + 1; + max[index] = Math.max(max[index], d); + min[index] = Math.min(min[index], d); + sum[index] = Double.sum(sum[index], d); + break; } } st = i + 1; } } - // System.out.println("Task " + kk + "Completed in " + (System.currentTimeMillis() - start)); - measurements.addAll(mst.stream() - .collect(groupingBy(Measurement::station, collector)) - .values()); - + // System.out.println("Task " + kk + "Completed in " + (System.nanoTime() - start)); + return new Store(min, max, sum, count); } catch (Exception e) { // throw new RuntimeException(e); - System.out.println(""); + throw e; } } }; } + + private static double getaDouble(byte[] value) { + double d = 0.0; + int s = -1; + for (int k = value.length - 1; k >= 0; k--) { + if (value[k] == 45) { + d = d * -1; + } + else if (value[k] == 46) { + } + else { + d = d + (((int) value[k]) - 48) * Math.pow(10, s); + s++; + } + } + return d; + } + + static class StringHolder implements Comparable { + byte[] bytes; + + public StringHolder(byte[] bytes) { + this.bytes = bytes; + } + + @Override + public String toString() { + return new String(this.bytes); + } + + @Override + public int hashCode() { + return Arrays.hashCode(this.bytes); + } + + @Override + public boolean equals(Object obj) { + return Arrays.equals(this.bytes, ((StringHolder) obj).bytes); + } + + @Override + public int compareTo(StringHolder o) { + return new String(this.bytes).compareTo(new String(o.bytes)); + } + } }