kumarsaurav123 # Attempt 3 (#470)

* Use Memory Segment

* Reduce Number of threads
This commit is contained in:
kumarsaurav123 2024-01-20 02:05:25 +05:30 committed by GitHub
parent 836f0805ad
commit f6bcaae4b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 137 additions and 140 deletions

View File

@ -16,6 +16,6 @@
#
JAVA_OPTS="-Xms6G -Xmx16G"
JAVA_OPTS="-Xms16G -Xmx32G --enable-preview"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_kumarsaurav123

View File

@ -15,18 +15,20 @@
*/
package dev.morling.onebrc;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collector;
import java.util.stream.IntStream;
import static java.util.stream.Collectors.groupingBy;
@ -40,7 +42,10 @@ public class CalculateAverage_kumarsaurav123 {
}
}
private static record ResultRow(String station,double min, double mean, double max,double sum,double count) {
private static record Pair(long start, int size) {
}
private static record ResultRow(String station, double min, double mean, double max, double sum, double count) {
/** Renders this row as {@code min/mean/max}, each value passed through {@code round}. */
public String toString() {
    StringBuilder text = new StringBuilder();
    text.append(round(min));
    text.append("/");
    text.append(round(mean));
    text.append("/");
    text.append(round(max));
    return text.toString();
}
@ -61,18 +66,13 @@ public class CalculateAverage_kumarsaurav123 {
private String station;
}
public static void main(String[] args) {
HashMap<Byte, Integer> map = new HashMap<>();
map.put((byte) 48, 0);
map.put((byte) 49, 1);
map.put((byte) 50, 2);
map.put((byte) 51, 3);
map.put((byte) 52, 4);
map.put((byte) 53, 5);
map.put((byte) 54, 6);
map.put((byte) 55, 7);
map.put((byte) 56, 8);
map.put((byte) 57, 9);
/**
 * Program entry point: aggregates the file named by {@code FILE} and prints the
 * resulting report to standard output.
 *
 * @param args command-line arguments; not used
 * @throws IOException propagated from {@code run}
 */
public static void main(String[] args) throws IOException {
    System.out.println(run(FILE));
}
public static String run(String filePath) throws IOException {
Collector<ResultRow, MeasurementAggregator, ResultRow> collector2 = Collector.of(
MeasurementAggregator::new,
(a, m) -> {
@ -91,7 +91,7 @@ public class CalculateAverage_kumarsaurav123 {
return res;
},
agg -> {
return new ResultRow(agg.station, agg.min, agg.sum / agg.count, agg.max, agg.sum, agg.count);
return new ResultRow(agg.station, agg.min, (Math.round(agg.sum * 10.0) / 10.0) / agg.count, agg.max, agg.sum, agg.count);
});
Collector<Measurement, MeasurementAggregator, ResultRow> collector = Collector.of(
MeasurementAggregator::new,
@ -114,143 +114,140 @@ public class CalculateAverage_kumarsaurav123 {
agg -> {
return new ResultRow(agg.station, agg.min, agg.sum / agg.count, agg.max, agg.sum, agg.count);
});
long start = System.currentTimeMillis();
long len = Paths.get(FILE).toFile().length();
Map<Integer, List<byte[]>> leftOutsMap = new ConcurrentSkipListMap<>();
int chunkSize = 1_0000_00;
long proc = Math.max(1, (len / chunkSize));
ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2 * 2 * 2);
ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
List<ResultRow> measurements = Collections.synchronizedList(new ArrayList<ResultRow>());
IntStream.range(0, (int) proc)
.mapToObj(i -> {
return new Runnable() {
@Override
public void run() {
try {
RandomAccessFile file = new RandomAccessFile(FILE, "r");
byte[] allBytes2 = new byte[chunkSize];
file.seek((long) i * (long) chunkSize);
int l = file.read(allBytes2);
byte[] eol = "\n".getBytes(StandardCharsets.UTF_8);
byte[] sep = ";".getBytes(StandardCharsets.UTF_8);
int chunkSize = 1_0000_00;
Map<Integer, List<byte[]>> leftOutsMap = new ConcurrentSkipListMap<>();
RandomAccessFile file = new RandomAccessFile(filePath, "r");
long filelength = file.length();
AtomicInteger kk = new AtomicInteger();
MemorySegment memorySegment = file.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, filelength, Arena.global());
int nChunks = 1000;
List<Measurement> mst = new ArrayList<>();
int st = 0;
int cnt = 0;
ArrayList<byte[]> local = new ArrayList<>();
int pChunkSize = Math.min(Integer.MAX_VALUE, (int) (memorySegment.byteSize() / (1000 * 20)));
if (pChunkSize < 100) {
pChunkSize = (int) memorySegment.byteSize();
nChunks = 1;
}
ArrayList<Pair> chunks = createStartAndEnd(pChunkSize, nChunks, memorySegment);
chunks.stream()
.map(p -> {
for (int i = 0; i < l; i++) {
if (allBytes2[i] == eol[0]) {
if (i != 0) {
byte[] s2 = new byte[i - st];
System.arraycopy(allBytes2, st, s2, 0, s2.length);
if (cnt != 0) {
for (int j = 0; j < s2.length; j++) {
if (s2[j] == sep[0]) {
byte[] city = new byte[j];
byte[] value = new byte[s2.length - j - 1];
System.arraycopy(s2, 0, city, 0, city.length);
System.arraycopy(s2, city.length + 1, value, 0, value.length);
double d = 0.0;
int s = -1;
for (int k = value.length - 1; k >= 0; k--) {
if (value[k] == 45) {
d = d * -1;
}
else if (value[k] == 46) {
}
else {
d = d + map.get(value[k]).intValue() * Math.pow(10, s);
s++;
}
}
mst.add(new Measurement(new String(city), d));
}
}
}
else {
local.add(s2);
}
}
cnt++;
st = i + 1;
}
}
if (st < l) {
byte[] s2 = new byte[allBytes2.length - st];
System.arraycopy(allBytes2, st, s2, 0, s2.length);
local.add(s2);
}
leftOutsMap.put(i, local);
allBytes2 = null;
measurements.addAll(mst.stream()
.collect(groupingBy(Measurement::station, collector))
.values());
// System.out.println(measurements.size());
}
catch (Exception e) {
// throw new RuntimeException(e);
System.out.println("");
}
}
};
return createRunnable(memorySegment, p, collector, measurements, kk.getAndIncrement());
})
.forEach(executor::submit);
executor.shutdown();
.forEach(executorService::submit);
executorService.shutdown();
try {
executor.awaitTermination(10, TimeUnit.MINUTES);
executorService.awaitTermination(10, TimeUnit.MINUTES);
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
Collection<Measurement> lMeasure = new ArrayList<>();
List<byte[]> leftOuts = leftOutsMap.values()
.stream()
.flatMap(List::stream)
.toList();
int size = 0;
for (int i = 0; i < leftOuts.size(); i++) {
size = size + leftOuts.get(i).length;
}
byte[] allBytes = new byte[size];
int pos = 0;
for (int i = 0; i < leftOuts.size(); i++) {
System.arraycopy(leftOuts.get(i), 0, allBytes, pos, leftOuts.get(i).length);
pos = pos + leftOuts.get(i).length;
}
List<String> l = Arrays.asList(new String(allBytes).split(";"));
List<Measurement> measurements1 = new ArrayList<>();
String city = l.get(0);
for (int i = 0; i < l.size() - 1; i++) {
int sIndex = l.get(i + 1).indexOf('.') + 2;
String tempp = l.get(i + 1).substring(0, sIndex);
measurements1.add(new Measurement(city, Double.parseDouble(tempp)));
city = l.get(i + 1).substring(sIndex);
}
measurements.addAll(measurements1.stream()
.collect(groupingBy(Measurement::station, collector))
.values());
Map<String, ResultRow> measurements2 = new TreeMap<>(measurements
.stream()
.parallel()
.collect(groupingBy(ResultRow::station, collector2)));
return measurements2.toString();
}
/**
 * Splits the mapped input into consecutive chunks that each end on a line boundary.
 *
 * <p>When {@code nChunks} is 1 a single chunk {@code (0, chunksize)} is returned as-is.
 * Otherwise the segment is walked in windows of {@code chunksize} bytes and every window
 * is cut back to the last newline it contains, so no line is split between two chunks.
 * If a window contains no newline at all (a line longer than the window, or a final
 * line without a trailing newline) the chunk is extended forward to the next newline
 * or to the end of the segment instead, which guarantees forward progress.
 *
 * @param chunksize     target chunk length in bytes; must be positive unless {@code nChunks} is 1
 * @param nChunks       only the value 1 is significant: it requests a single chunk
 * @param memorySegment the mapped input to partition
 * @return the chunks in file order; together they cover the whole segment
 * @throws IllegalArgumentException if {@code chunksize} is not positive and {@code nChunks} is not 1
 * @throws ArithmeticException      if a single chunk would exceed {@code Integer.MAX_VALUE} bytes
 */
private static ArrayList<Pair> createStartAndEnd(int chunksize, int nChunks, MemorySegment memorySegment) {
    ArrayList<Pair> startSizePairs = new ArrayList<>();
    if (nChunks == 1) {
        startSizePairs.add(new Pair(0, chunksize));
        return startSizePairs;
    }
    if (chunksize <= 0) {
        throw new IllegalArgumentException("chunksize must be positive: " + chunksize);
    }
    final byte eol = '\n';
    final long total = memorySegment.byteSize();
    long start = 0;
    while (start < total) {
        long windowEnd = Math.min(total - 1, start + chunksize - 1);
        // Back up to the last newline inside the window, never stepping before the chunk start.
        long end = windowEnd;
        while (end >= start && memorySegment.get(ValueLayout.JAVA_BYTE, end) != eol) {
            end--;
        }
        if (end < start) {
            // No newline in the window: grow the chunk forward to the next newline,
            // or to the last byte when the input has no trailing newline.
            end = windowEnd;
            while (end < total - 1 && memorySegment.get(ValueLayout.JAVA_BYTE, end) != eol) {
                end++;
            }
        }
        startSizePairs.add(new Pair(start, Math.toIntExact(end - start + 1)));
        start = end + 1;
    }
    return startSizePairs;
}
System.out.println(measurements2);
// System.out.println(System.currentTimeMillis() - start);
/**
 * Builds the task that parses one chunk of the mapped input and appends that chunk's
 * per-station aggregates to {@code measurements}.
 *
 * <p>The chunk is copied out of the segment, split on {@code '\n'}, and every line of the
 * form {@code station;value} becomes one {@code Measurement}. A final line without a
 * trailing newline is parsed as well. Lines without a {@code ';'} are ignored.
 *
 * <p>The task never throws: if anything goes wrong while reading or parsing, nothing from
 * this chunk is added and the failure is reported on standard error, so the report
 * written to standard output is not polluted.
 *
 * @param memorySegment the mapped input
 * @param p             offset and length of the chunk to process
 * @param collector     collector that folds the measurements of one station into a {@code ResultRow}
 * @param measurements  shared result list; receives one row per station found in the chunk
 * @param kk            sequence number of the chunk, used only to label error reports
 * @return a task suitable for submission to an executor
 */
public static Runnable createRunnable(MemorySegment memorySegment, Pair p, Collector<Measurement, MeasurementAggregator, ResultRow> collector,
        List<ResultRow> measurements, int kk) {
    return () -> {
        try {
            byte[] chunk = new byte[p.size];
            memorySegment.asSlice(p.start, p.size).asByteBuffer().get(chunk);
            List<Measurement> parsed = new ArrayList<>();
            int lineStart = 0;
            for (int i = 0; i < chunk.length; i++) {
                if (chunk[i] == '\n') {
                    appendMeasurement(chunk, lineStart, i, parsed);
                    lineStart = i + 1;
                }
            }
            if (lineStart < chunk.length) {
                // Last record of the input when the file has no trailing newline.
                appendMeasurement(chunk, lineStart, chunk.length, parsed);
            }
            measurements.addAll(parsed.stream()
                    .collect(groupingBy(Measurement::station, collector))
                    .values());
        }
        catch (RuntimeException e) {
            // Deliberately not rethrown (best effort per chunk), but never silent.
            System.err.println("Chunk " + kk + " (offset " + p.start + ", " + p.size + " bytes) was skipped: " + e);
        }
    };
}

/** Parses the line {@code chunk[from, to)} and, if it contains a {@code ';'}, adds its measurement to {@code out}. */
private static void appendMeasurement(byte[] chunk, int from, int to, List<Measurement> out) {
    for (int j = from; j < to; j++) {
        if (chunk[j] == ';') {
            String station = new String(chunk, from, j - from, StandardCharsets.UTF_8);
            out.add(new Measurement(station, parseTemperature(chunk, j + 1, to)));
            return;
        }
    }
}

/**
 * Decodes the decimal text in {@code chunk[from, to)}, which is expected to carry exactly
 * one fractional digit, reading it right to left. An empty range yields {@code 0.0}.
 */
private static double parseTemperature(byte[] chunk, int from, int to) {
    double d = 0.0;
    int exponent = -1; // the right-most digit is the tenths digit
    for (int k = to - 1; k >= from; k--) {
        byte b = chunk[k];
        if (b == '-') {
            d = d * -1;
        }
        else if (b == '.') {
            // The decimal point carries no value of its own.
        }
        else if (b >= '0' && b <= '9') {
            d = d + (b - '0') * Math.pow(10, exponent);
            exponent++;
        }
        else {
            throw new NumberFormatException("Unexpected byte " + b + " in temperature at chunk index " + k);
        }
    }
    return d;
}
}