using unsafe alone (#512)

* final commit

changing using mappedbytebuffer

changes before using unsafe address

using unsafe

* using graalvm, correct unsafe mem implementation

---------

Co-authored-by: Karthikeyans <karthikeyan.sn@zohocorp.com>
This commit is contained in:
karthikeyan97 2024-01-21 01:19:54 +05:30 committed by GitHub
parent ac26c8b644
commit f49a92019e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 114 additions and 109 deletions

View File

@ -16,4 +16,14 @@
# #
JAVA_OPTS="-Xms20480m -Xmx40960m " JAVA_OPTS="-Xms20480m -Xmx40960m "
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_karthikeyan97
# Run the prebuilt native image when the file exists; otherwise fall back to the JVM.
if [ -f target/CalculateAverage_karthikeyan97_image ]; then
#echo "Picking up existing native image 'target/CalculateAverage_karthikeyan97_image', delete the file to select JVM mode." 1>&2
# Launch the native executable, passing the heap settings as its arguments.
target/CalculateAverage_karthikeyan97_image -Xms20480m -Xmx32768m
else
# NOTE(review): JAVA_OPTS is assigned here, but the java command below spells
# its flags out literally and does not reference the variable.
JAVA_OPTS="--enable-preview"
#echo "Choosing to run the app in JVM mode as no native image was found, use prepare_karthikeyan97.sh to generate." 1>&2
# JVM mode: same heap settings as the native branch, with preview features enabled.
java -Xms20480m -Xmx32768m --enable-preview --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_karthikeyan97
fi

View File

@ -22,9 +22,12 @@ import static java.util.stream.Collectors.*;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.RandomAccessFile; import java.io.RandomAccessFile;
import java.lang.foreign.Arena;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel; import java.nio.channels.FileChannel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -33,6 +36,7 @@ import java.util.HashMap;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Scanner;
import java.util.Set; import java.util.Set;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
@ -44,8 +48,21 @@ import java.util.stream.Collectors;
public class CalculateAverage_karthikeyan97 { public class CalculateAverage_karthikeyan97 {
private static final Unsafe UNSAFE = initUnsafe();
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
/**
 * Obtains an {@code Unsafe} instance by reflectively reading the field named
 * {@code theUnsafe} declared on the {@code Unsafe} class.
 *
 * @return the value of the {@code theUnsafe} field, cast to {@code Unsafe}
 * @throws RuntimeException wrapping the {@link NoSuchFieldException} or
 *         {@link IllegalAccessException} raised by the reflective lookup or read
 */
private static Unsafe initUnsafe() {
try {
Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
// Lift Java access checks so the field can be read from this class.
theUnsafe.setAccessible(true);
// NOTE(review): a Class object is passed as the target instance; this only
// works if theUnsafe is a static field (the argument is then ignored) - confirm.
return (Unsafe) theUnsafe.get(Unsafe.class);
}
catch (NoSuchFieldException | IllegalAccessException e) {
// Rethrow as unchecked, keeping the original exception as the cause.
throw new RuntimeException(e);
}
}
private record Measurement(modifiedbytearray station, double value) { private record Measurement(modifiedbytearray station, double value) {
} }
@ -53,18 +70,18 @@ public class CalculateAverage_karthikeyan97 {
} }
private static class MeasurementAggregator { private static class MeasurementAggregator {
private double min = Double.POSITIVE_INFINITY; private long min = Long.MAX_VALUE;
private double max = Double.NEGATIVE_INFINITY; private long max = Long.MIN_VALUE;
private long sum; private long sum;
private long count; private long count;
public String toString() { public String toString() {
return new StringBuffer(14) return new StringBuffer(14)
.append(round(min)) .append(round((1.0 * min)))
.append("/") .append("/")
.append(round((1.0 * sum) / count)) .append(round((1.0 * sum) / count))
.append("/") .append("/")
.append(round(max)).toString(); .append(round((1.0 * max))).toString();
} }
private double round(double value) { private double round(double value) {
@ -74,7 +91,7 @@ public class CalculateAverage_karthikeyan97 {
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
// long start = System.nanoTime(); // long start = System.nanoTime();
System.setSecurityManager(null); // System.setSecurityManager(null);
Collector<Map.Entry<modifiedbytearray, MeasurementAggregator>, MeasurementAggregator, MeasurementAggregator> collector = Collector.of( Collector<Map.Entry<modifiedbytearray, MeasurementAggregator>, MeasurementAggregator, MeasurementAggregator> collector = Collector.of(
MeasurementAggregator::new, MeasurementAggregator::new,
(a, m) -> { (a, m) -> {
@ -103,15 +120,17 @@ public class CalculateAverage_karthikeyan97 {
}, },
agg -> agg); agg -> agg);
RandomAccessFile raf = new RandomAccessFile(FILE, "rw"); RandomAccessFile raf = new RandomAccessFile(FILE, "r");
FileChannel fileChannel = raf.getChannel();
final long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, raf.length(), Arena.global()).address();
long length = raf.length(); long length = raf.length();
int cores = length > 1000 ? Runtime.getRuntime().availableProcessors() : 1; final long endAddress = mappedAddress + length - 1;
int cores = length > 1000 ? Runtime.getRuntime().availableProcessors() * 2 : 1;
long boundary[][] = new long[cores][2]; long boundary[][] = new long[cores][2];
long segments = length / (cores); long segments = length / (cores);
long before = -1; long before = -1;
for (int i = 0; i < cores - 1; i++) { for (int i = 0; i < cores - 1; i++) {
boundary[i][0] = before + 1; boundary[i][0] = before + 1;
byte[] b = new byte[107];
if (before + segments - 107 > 0) { if (before + segments - 107 > 0) {
raf.seek(before + segments - 107); raf.seek(before + segments - 107);
} }
@ -130,120 +149,92 @@ public class CalculateAverage_karthikeyan97 {
f.setAccessible(true); f.setAccessible(true);
Unsafe unsafe = (Unsafe) f.get(null); Unsafe unsafe = (Unsafe) f.get(null);
int pageSize = unsafe.pageSize() * 10; int l3Size = (13 * 1024 * 1024);// unsafe.l3Size();
System.out.println(new TreeMap((Arrays.stream(boundary).parallel().map(i -> { System.out.println(new TreeMap((Arrays.stream(boundary).parallel().map(i -> {
FileInputStream fileInputStream = null; FileInputStream fileInputStream = null;
try { try {
fileInputStream = new FileInputStream(FILE); int seglen = (int) (i[1] - i[0] + 1);
FileChannel fileChannel = fileInputStream.getChannel(); HashMap<modifiedbytearray, MeasurementAggregator> resultmap = new HashMap<>(1000);
HashMap<modifiedbytearray, MeasurementAggregator> resultmap = new HashMap<>(12000, 100); long segstart = mappedAddress + i[0];
int bytesRemaining = seglen;
ByteBuffer buffer = ByteBuffer.allocateDirect(pageSize); long num = 0;
fileChannel.position(i[0]);
int bytesReading = 0;
double num = 0;
int sign = 1; int sign = 1;
boolean isNumber = false; boolean isNumber = false;
byte bi; byte bi;
modifiedbytearray stationName = null; modifiedbytearray stationName = null;
int hascode = 1; int hascode = 5381;
int ctr = 0; while (bytesRemaining > 0) {
byte[] arr = new byte[100];
int arrptr = 0;
int seglen = (int) (i[1] - i[0] + 1);
while (bytesReading < seglen) {
buffer.clear();
int bytesRead = fileChannel.read(buffer);
if ((bytesReading + bytesRead) <= seglen) {
if (bytesRead < 0) {
bytesRead = 0;
}
}
else {
bytesRead = (seglen - bytesReading);
}
buffer.flip();
int bytesptr = 0; int bytesptr = 0;
byte[] bufferArr = new byte[bytesRead]; // int bytesread = buffer.remaining() > l3Size ? l3Size : buffer.remaining();
buffer.get(bufferArr); // byte[] bufferArr = new byte[bytesread];
while (bytesptr < bytesRead) { // buffer.get(bufferArr);
bytesReading += 1; int bbstart = 0;
bi = bufferArr[bytesptr++]; int readSize = bytesRemaining > l3Size ? l3Size : bytesRemaining;
if (ctr > 0) { int actualReadSize = (segstart + readSize + 110 > endAddress || readSize + 110 > i[1]) ? readSize : readSize + 110;
arr[arrptr++] = bi; byte[] readArr = new byte[actualReadSize];
hascode = 31 * hascode + bi;
ctr--; UNSAFE.copyMemory(null, segstart, readArr, UNSAFE.ARRAY_BYTE_BASE_OFFSET, actualReadSize);
} while (bytesptr < actualReadSize) {
else { bi = readArr[bytesptr++];// UNSAFE.getByte(segstart + bytesReading++);
if (bi >= 240) { if (!isNumber) {
arr[arrptr++] = bi; if (bi >= 192) {
hascode = 31 * hascode + bi; hascode = (hascode << 5) + hascode ^ bi;
ctr = 3;
}
else if (bi >= 224) {
arr[arrptr++] = bi;
hascode = 31 * hascode + bi;
ctr = 2;
}
else if (bi >= 192) {
arr[arrptr++] = bi;
hascode = 31 * hascode + bi;
ctr = 1;
} }
else if (bi == 59) { else if (bi == 59) {
isNumber = true; isNumber = true;
stationName = new modifiedbytearray(arr, arrptr, hascode); stationName = new modifiedbytearray(readArr, bbstart, bytesptr - 2, hascode & 0xFFFFFFFF);
arr = new byte[100]; bbstart = 0;
arrptr = 0; hascode = 5381;
hascode = 1; if (bytesptr >= readSize) {
} break;
else if (bi == 10) {
hascode = 1;
isNumber = false;
MeasurementAggregator agg = resultmap.get(stationName);
num *= sign;
if (agg == null) {
agg = new MeasurementAggregator();
agg.min = num;
agg.max = num;
agg.sum = (long) (num);
agg.count = 1;
resultmap.put(stationName, agg);
} }
else {
if (agg.min >= num) {
agg.min = num;
}
if (agg.max <= num) {
agg.max = num;
}
agg.sum += (long) (num);
agg.count++;
}
num = 0;
sign = 1;
} }
else { else {
hascode = 31 * hascode + bi; hascode = (hascode << 5) + hascode ^ bi;
if (isNumber) { }
switch (bi) { }
case 0x2E: else {
break; switch (bi) {
case 0x2D: case 0x2E:
sign = -1; break;
break; case 0x2D:
default: sign = -1;
num = num * 10 + (bi - 0x30); break;
case 10:
hascode = 5381;
isNumber = false;
bbstart = bytesptr;
MeasurementAggregator agg = resultmap.get(stationName);
num *= sign;
if (agg == null) {
agg = new MeasurementAggregator();
agg.min = num;
agg.max = num;
agg.sum = (long) (num);
agg.count = 1;
resultmap.put(stationName, agg);
} }
} else {
else { if (agg.min >= num) {
arr[arrptr++] = bi; agg.min = num;
} }
if (agg.max <= num) {
agg.max = num;
}
agg.sum += (long) (num);
agg.count++;
}
num = 0;
sign = 1;
break;
default:
num = num * 10 + (bi - 0x30);
} }
} }
} }
bytesRemaining -= bytesptr;
segstart += bytesptr;
} }
/* /*
* while (bytesReading < (i[1] - i[0] + 1) && buffer.position() < buffer.limit()) { * while (bytesReading < (i[1] - i[0] + 1) && buffer.position() < buffer.limit()) {
@ -335,7 +326,7 @@ public class CalculateAverage_karthikeyan97 {
*/ */
// Get the FileChannel from the FileInputStream // Get the FileChannel from the FileInputStream
// System.out.println("time taken:" + (System.nanoTime() - start) / 1000000); // System.out.println("time taken1:" + (System.nanoTime() - start) / 1000000);
// System.out.println(measurements); // System.out.println(measurements);
} }
@ -343,17 +334,21 @@ public class CalculateAverage_karthikeyan97 {
class modifiedbytearray { class modifiedbytearray {
private int length; private int length;
private int start;
private int end;
private byte[] arr; private byte[] arr;
public int hashcode; public int hashcode;
modifiedbytearray(byte[] arr, int length, int hashcode) { modifiedbytearray(byte[] arr, int start, int end, int hashcode) {
this.arr = arr; this.arr = arr;
this.length = length; this.length = end - start + 1;
this.end = end;
this.start = start;
this.hashcode = hashcode; this.hashcode = hashcode;
} }
public String getStationName() { public String getStationName() {
return new String(this.getArr(), 0, length, StandardCharsets.UTF_8); return new String(this.getArr(), start, length, StandardCharsets.UTF_8);
} }
public byte[] getArr() { public byte[] getArr() {
@ -368,7 +363,7 @@ class modifiedbytearray {
@Override @Override
public boolean equals(Object obj) { public boolean equals(Object obj) {
modifiedbytearray b = (modifiedbytearray) obj; modifiedbytearray b = (modifiedbytearray) obj;
return Arrays.equals(this.getArr(), 0, length, b.arr, 0, b.length); return Arrays.equals(this.getArr(), start, end, b.arr, b.start, b.end);
} }
public int getHashcode() { public int getHashcode() {