kumarsaurav123 # Attempt 3 (#470)
* Use Memory Segment * Reduce Number of threads
This commit is contained in:
		@@ -16,6 +16,6 @@
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
JAVA_OPTS="-Xms6G -Xmx16G"
 | 
			
		||||
JAVA_OPTS="-Xms16G -Xmx32G --enable-preview"
 | 
			
		||||
 | 
			
		||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_kumarsaurav123
 | 
			
		||||
 
 | 
			
		||||
@@ -15,18 +15,20 @@
 | 
			
		||||
 */
 | 
			
		||||
package dev.morling.onebrc;
 | 
			
		||||
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import java.io.RandomAccessFile;
 | 
			
		||||
import java.nio.ByteBuffer;
 | 
			
		||||
import java.nio.ByteOrder;
 | 
			
		||||
import java.lang.foreign.Arena;
 | 
			
		||||
import java.lang.foreign.MemorySegment;
 | 
			
		||||
import java.lang.foreign.ValueLayout;
 | 
			
		||||
import java.nio.channels.FileChannel;
 | 
			
		||||
import java.nio.charset.StandardCharsets;
 | 
			
		||||
import java.nio.file.Paths;
 | 
			
		||||
import java.util.*;
 | 
			
		||||
import java.util.concurrent.ConcurrentSkipListMap;
 | 
			
		||||
import java.util.concurrent.ExecutorService;
 | 
			
		||||
import java.util.concurrent.Executors;
 | 
			
		||||
import java.util.concurrent.TimeUnit;
 | 
			
		||||
import java.util.concurrent.atomic.AtomicInteger;
 | 
			
		||||
import java.util.stream.Collector;
 | 
			
		||||
import java.util.stream.IntStream;
 | 
			
		||||
 | 
			
		||||
import static java.util.stream.Collectors.groupingBy;
 | 
			
		||||
 | 
			
		||||
@@ -40,7 +42,10 @@ public class CalculateAverage_kumarsaurav123 {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static record ResultRow(String station,double min, double mean, double max,double sum,double count) {
 | 
			
		||||
    private static record Pair(long start, int size) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static record ResultRow(String station, double min, double mean, double max, double sum, double count) {
 | 
			
		||||
        public String toString() {
 | 
			
		||||
            return round(min) + "/" + round(mean) + "/" + round(max);
 | 
			
		||||
        }
 | 
			
		||||
@@ -61,18 +66,13 @@ public class CalculateAverage_kumarsaurav123 {
 | 
			
		||||
        private String station;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static void main(String[] args) {
 | 
			
		||||
        HashMap<Byte, Integer> map = new HashMap<>();
 | 
			
		||||
        map.put((byte) 48, 0);
 | 
			
		||||
        map.put((byte) 49, 1);
 | 
			
		||||
        map.put((byte) 50, 2);
 | 
			
		||||
        map.put((byte) 51, 3);
 | 
			
		||||
        map.put((byte) 52, 4);
 | 
			
		||||
        map.put((byte) 53, 5);
 | 
			
		||||
        map.put((byte) 54, 6);
 | 
			
		||||
        map.put((byte) 55, 7);
 | 
			
		||||
        map.put((byte) 56, 8);
 | 
			
		||||
        map.put((byte) 57, 9);
 | 
			
		||||
    public static void main(String[] args) throws IOException {
 | 
			
		||||
        long start = System.currentTimeMillis();
 | 
			
		||||
        System.out.println(run(FILE));
 | 
			
		||||
        // System.out.println(System.currentTimeMillis() - start);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static String run(String filePath) throws IOException {
 | 
			
		||||
        Collector<ResultRow, MeasurementAggregator, ResultRow> collector2 = Collector.of(
 | 
			
		||||
                MeasurementAggregator::new,
 | 
			
		||||
                (a, m) -> {
 | 
			
		||||
@@ -91,7 +91,7 @@ public class CalculateAverage_kumarsaurav123 {
 | 
			
		||||
                    return res;
 | 
			
		||||
                },
 | 
			
		||||
                agg -> {
 | 
			
		||||
                    return new ResultRow(agg.station, agg.min, agg.sum / agg.count, agg.max, agg.sum, agg.count);
 | 
			
		||||
                    return new ResultRow(agg.station, agg.min, (Math.round(agg.sum * 10.0) / 10.0) / agg.count, agg.max, agg.sum, agg.count);
 | 
			
		||||
                });
 | 
			
		||||
        Collector<Measurement, MeasurementAggregator, ResultRow> collector = Collector.of(
 | 
			
		||||
                MeasurementAggregator::new,
 | 
			
		||||
@@ -114,38 +114,103 @@ public class CalculateAverage_kumarsaurav123 {
 | 
			
		||||
                agg -> {
 | 
			
		||||
                    return new ResultRow(agg.station, agg.min, agg.sum / agg.count, agg.max, agg.sum, agg.count);
 | 
			
		||||
                });
 | 
			
		||||
 | 
			
		||||
        long start = System.currentTimeMillis();
 | 
			
		||||
        long len = Paths.get(FILE).toFile().length();
 | 
			
		||||
        Map<Integer, List<byte[]>> leftOutsMap = new ConcurrentSkipListMap<>();
 | 
			
		||||
        int chunkSize = 1_0000_00;
 | 
			
		||||
        long proc = Math.max(1, (len / chunkSize));
 | 
			
		||||
        ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2 * 2 * 2);
 | 
			
		||||
        ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
 | 
			
		||||
        List<ResultRow> measurements = Collections.synchronizedList(new ArrayList<ResultRow>());
 | 
			
		||||
        IntStream.range(0, (int) proc)
 | 
			
		||||
                .mapToObj(i -> {
 | 
			
		||||
        int chunkSize = 1_0000_00;
 | 
			
		||||
        Map<Integer, List<byte[]>> leftOutsMap = new ConcurrentSkipListMap<>();
 | 
			
		||||
        RandomAccessFile file = new RandomAccessFile(filePath, "r");
 | 
			
		||||
        long filelength = file.length();
 | 
			
		||||
        AtomicInteger kk = new AtomicInteger();
 | 
			
		||||
        MemorySegment memorySegment = file.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, filelength, Arena.global());
 | 
			
		||||
        int nChunks = 1000;
 | 
			
		||||
 | 
			
		||||
        int pChunkSize = Math.min(Integer.MAX_VALUE, (int) (memorySegment.byteSize() / (1000 * 20)));
 | 
			
		||||
        if (pChunkSize < 100) {
 | 
			
		||||
            pChunkSize = (int) memorySegment.byteSize();
 | 
			
		||||
            nChunks = 1;
 | 
			
		||||
        }
 | 
			
		||||
        ArrayList<Pair> chunks = createStartAndEnd(pChunkSize, nChunks, memorySegment);
 | 
			
		||||
        chunks.stream()
 | 
			
		||||
                .map(p -> {
 | 
			
		||||
 | 
			
		||||
                    return createRunnable(memorySegment, p, collector, measurements, kk.getAndIncrement());
 | 
			
		||||
                })
 | 
			
		||||
                .forEach(executorService::submit);
 | 
			
		||||
        executorService.shutdown();
 | 
			
		||||
        try {
 | 
			
		||||
            executorService.awaitTermination(10, TimeUnit.MINUTES);
 | 
			
		||||
        }
 | 
			
		||||
        catch (InterruptedException e) {
 | 
			
		||||
            throw new RuntimeException(e);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Map<String, ResultRow> measurements2 = new TreeMap<>(measurements
 | 
			
		||||
                .stream()
 | 
			
		||||
                .parallel()
 | 
			
		||||
                .collect(groupingBy(ResultRow::station, collector2)));
 | 
			
		||||
        return measurements2.toString();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static ArrayList<Pair> createStartAndEnd(int chunksize, int nChunks, MemorySegment memorySegment) {
 | 
			
		||||
        ArrayList<Pair> startSizePairs = new ArrayList<>();
 | 
			
		||||
        byte eol = "\n".getBytes(StandardCharsets.UTF_8)[0];
 | 
			
		||||
        long start = 0;
 | 
			
		||||
        long end = -1;
 | 
			
		||||
        if (nChunks == 1) {
 | 
			
		||||
            startSizePairs.add(new Pair(0, chunksize));
 | 
			
		||||
            return startSizePairs;
 | 
			
		||||
        }
 | 
			
		||||
        else {
 | 
			
		||||
            while (start < memorySegment.byteSize()) {
 | 
			
		||||
                start = end + 1;
 | 
			
		||||
                end = Math.min(memorySegment.byteSize() - 1, start + chunksize - 1);
 | 
			
		||||
                while (memorySegment.get(ValueLayout.JAVA_BYTE, end) != eol) {
 | 
			
		||||
                    end--;
 | 
			
		||||
 | 
			
		||||
                }
 | 
			
		||||
                startSizePairs.add(new Pair(start, (int) (end - start + 1)));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return startSizePairs;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static Runnable createRunnable(MemorySegment memorySegment, Pair p, Collector<Measurement, MeasurementAggregator, ResultRow> collector,
 | 
			
		||||
                                          List<ResultRow> measurements, int kk) {
 | 
			
		||||
        return new Runnable() {
 | 
			
		||||
            @Override
 | 
			
		||||
            public void run() {
 | 
			
		||||
                try {
 | 
			
		||||
                                RandomAccessFile file = new RandomAccessFile(FILE, "r");
 | 
			
		||||
                                byte[] allBytes2 = new byte[chunkSize];
 | 
			
		||||
                                file.seek((long) i * (long) chunkSize);
 | 
			
		||||
                                int l = file.read(allBytes2);
 | 
			
		||||
                    long start = System.currentTimeMillis();
 | 
			
		||||
 | 
			
		||||
                    byte[] allBytes2 = new byte[p.size];
 | 
			
		||||
                    MemorySegment lMemory = memorySegment.asSlice(p.start, p.size);
 | 
			
		||||
                    lMemory.asByteBuffer().get(allBytes2);
 | 
			
		||||
                    HashMap<Byte, Integer> map = new HashMap<>();
 | 
			
		||||
                    // Runtime runtime = Runtime.getRuntime();
 | 
			
		||||
                    // long memoryMax = runtime.maxMemory();
 | 
			
		||||
                    // long memoryUsed = runtime.totalMemory() - runtime.freeMemory();
 | 
			
		||||
                    // double memoryUsedPercent = (memoryUsed * 100.0) / memoryMax;
 | 
			
		||||
                    // System.out.println("memoryUsedPercent: " + memoryUsedPercent);
 | 
			
		||||
                    map.put((byte) 48, 0);
 | 
			
		||||
                    map.put((byte) 49, 1);
 | 
			
		||||
                    map.put((byte) 50, 2);
 | 
			
		||||
                    map.put((byte) 51, 3);
 | 
			
		||||
                    map.put((byte) 52, 4);
 | 
			
		||||
                    map.put((byte) 53, 5);
 | 
			
		||||
                    map.put((byte) 54, 6);
 | 
			
		||||
                    map.put((byte) 55, 7);
 | 
			
		||||
                    map.put((byte) 56, 8);
 | 
			
		||||
                    map.put((byte) 57, 9);
 | 
			
		||||
                    byte[] eol = "\n".getBytes(StandardCharsets.UTF_8);
 | 
			
		||||
                    byte[] sep = ";".getBytes(StandardCharsets.UTF_8);
 | 
			
		||||
 | 
			
		||||
                    List<Measurement> mst = new ArrayList<>();
 | 
			
		||||
                    int st = 0;
 | 
			
		||||
                                int cnt = 0;
 | 
			
		||||
                                ArrayList<byte[]> local = new ArrayList<>();
 | 
			
		||||
 | 
			
		||||
                                for (int i = 0; i < l; i++) {
 | 
			
		||||
                    for (int i = 0; i < allBytes2.length; i++) {
 | 
			
		||||
                        if (allBytes2[i] == eol[0]) {
 | 
			
		||||
                                        if (i != 0) {
 | 
			
		||||
                            byte[] s2 = new byte[i - st];
 | 
			
		||||
                            System.arraycopy(allBytes2, st, s2, 0, s2.length);
 | 
			
		||||
                                            if (cnt != 0) {
 | 
			
		||||
                            for (int j = 0; j < s2.length; j++) {
 | 
			
		||||
                                if (s2[j] == sep[0]) {
 | 
			
		||||
                                    byte[] city = new byte[j];
 | 
			
		||||
@@ -169,28 +234,14 @@ public class CalculateAverage_kumarsaurav123 {
 | 
			
		||||
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                                            }
 | 
			
		||||
                                            else {
 | 
			
		||||
                                                local.add(s2);
 | 
			
		||||
                                            }
 | 
			
		||||
 | 
			
		||||
                                        }
 | 
			
		||||
                                        cnt++;
 | 
			
		||||
                            st = i + 1;
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                                if (st < l) {
 | 
			
		||||
                                    byte[] s2 = new byte[allBytes2.length - st];
 | 
			
		||||
                                    System.arraycopy(allBytes2, st, s2, 0, s2.length);
 | 
			
		||||
                                    local.add(s2);
 | 
			
		||||
                                }
 | 
			
		||||
                                leftOutsMap.put(i, local);
 | 
			
		||||
                                allBytes2 = null;
 | 
			
		||||
                    // System.out.println("Task " + kk + "Completed in " + (System.currentTimeMillis() - start));
 | 
			
		||||
                    measurements.addAll(mst.stream()
 | 
			
		||||
                            .collect(groupingBy(Measurement::station, collector))
 | 
			
		||||
                            .values());
 | 
			
		||||
                                // System.out.println(measurements.size());
 | 
			
		||||
 | 
			
		||||
                }
 | 
			
		||||
                catch (Exception e) {
 | 
			
		||||
                    // throw new RuntimeException(e);
 | 
			
		||||
@@ -198,59 +249,5 @@ public class CalculateAverage_kumarsaurav123 {
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
                })
 | 
			
		||||
                .forEach(executor::submit);
 | 
			
		||||
        executor.shutdown();
 | 
			
		||||
 | 
			
		||||
        try {
 | 
			
		||||
            executor.awaitTermination(10, TimeUnit.MINUTES);
 | 
			
		||||
        }
 | 
			
		||||
        catch (InterruptedException e) {
 | 
			
		||||
            throw new RuntimeException(e);
 | 
			
		||||
        }
 | 
			
		||||
        Collection<Measurement> lMeasure = new ArrayList<>();
 | 
			
		||||
        List<byte[]> leftOuts = leftOutsMap.values()
 | 
			
		||||
                .stream()
 | 
			
		||||
                .flatMap(List::stream)
 | 
			
		||||
                .toList();
 | 
			
		||||
        int size = 0;
 | 
			
		||||
        for (int i = 0; i < leftOuts.size(); i++) {
 | 
			
		||||
            size = size + leftOuts.get(i).length;
 | 
			
		||||
        }
 | 
			
		||||
        byte[] allBytes = new byte[size];
 | 
			
		||||
        int pos = 0;
 | 
			
		||||
        for (int i = 0; i < leftOuts.size(); i++) {
 | 
			
		||||
            System.arraycopy(leftOuts.get(i), 0, allBytes, pos, leftOuts.get(i).length);
 | 
			
		||||
            pos = pos + leftOuts.get(i).length;
 | 
			
		||||
        }
 | 
			
		||||
        List<String> l = Arrays.asList(new String(allBytes).split(";"));
 | 
			
		||||
        List<Measurement> measurements1 = new ArrayList<>();
 | 
			
		||||
        String city = l.get(0);
 | 
			
		||||
        for (int i = 0; i < l.size() - 1; i++) {
 | 
			
		||||
            int sIndex = l.get(i + 1).indexOf('.') + 2;
 | 
			
		||||
 | 
			
		||||
            String tempp = l.get(i + 1).substring(0, sIndex);
 | 
			
		||||
 | 
			
		||||
            measurements1.add(new Measurement(city, Double.parseDouble(tempp)));
 | 
			
		||||
            city = l.get(i + 1).substring(sIndex);
 | 
			
		||||
        }
 | 
			
		||||
        measurements.addAll(measurements1.stream()
 | 
			
		||||
                .collect(groupingBy(Measurement::station, collector))
 | 
			
		||||
                .values());
 | 
			
		||||
        Map<String, ResultRow> measurements2 = new TreeMap<>(measurements
 | 
			
		||||
                .stream()
 | 
			
		||||
                .parallel()
 | 
			
		||||
                .collect(groupingBy(ResultRow::station, collector2)));
 | 
			
		||||
 | 
			
		||||
        // Read from bytes 1000 to 2000
 | 
			
		||||
        // Something like this
 | 
			
		||||
 | 
			
		||||
        //
 | 
			
		||||
        // Map<String, ResultRow> measurements = new TreeMap<>(Files.lines(Paths.get(FILE))
 | 
			
		||||
        // .map(l -> new Measurement(l.split(";")))
 | 
			
		||||
        // .collect(groupingBy(m -> m.station(), collector)));
 | 
			
		||||
 | 
			
		||||
        System.out.println(measurements2);
 | 
			
		||||
        // System.out.println(System.currentTimeMillis() - start);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user