My Probably last attempt to optimize performance (#693)

* CalculateAverage_pdrakatos

* Rename to be valid with rules

* CalculateAverage_pdrakatos

* Rename to be valid with rules

* Changes on scripts execution

* Fixing bugs causing scripts not to be executed

* Changes on prepare make it compatible

* Fixing passing all tests

* Increase direct memory allocation buffer

* Fixing memory problem causes heap space exception

* Fresh solution to optimize performance of the execution

* New Fresh solution with optimized performance with Custom Hashtable

* Increase maxperm size and xmx to avoid heap spaces error
This commit is contained in:
Panagiotis Drakatos 2024-02-01 13:02:45 +02:00 committed by GitHub
parent 1e7314d5fb
commit 2aed039f17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 236 additions and 58 deletions

View File

@ -32,5 +32,5 @@
# #
source "$HOME/.sdkman/bin/sdkman-init.sh" source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.1-graal 1>&2 sdk use java 21.0.1-graal 1>&2
JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA" JAVA_OPTS="--enable-preview -Xms1536m -Xmx10536m -XX:NewSize=256m -XX:MaxNewSize=512m -XX:MaxMetaspaceSize=512m -XX:+DisableExplicitGC -XX:+UseSerialGC -XX:-TieredCompilation -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos

View File

@ -18,6 +18,6 @@ source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.1-graal 1>&2 sdk use java 21.0.1-graal 1>&2
if [ ! -f target/CalculateAverage_PanagiotisDrakatos_image ]; then if [ ! -f target/CalculateAverage_PanagiotisDrakatos_image ]; then
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -R:MaxHeapSize=64m --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos" NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -R:MaxHeapSize=10536m --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos"
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_PanagiotisDrakatos_image dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_PanagiotisDrakatos_image dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos
fi fi

View File

@ -20,41 +20,38 @@ import java.io.FileInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.RandomAccessFile; import java.io.RandomAccessFile;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer; import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel; import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.*; import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import java.util.stream.StreamSupport;
public class CalculateAverage_PanagiotisDrakatos { public class CalculateAverage_PanagiotisDrakatos {
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
private static final long SEGMENT_SIZE = 4 * 1024 * 1024; private static final long MAP_SIZE = 1024 * 1024 * 12L;
private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL;
private static final long DOT_BITS = 0x10101000;
private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
private static TreeMap<String, MeasurementObject> sortedCities; private static TreeMap<String, MeasurementObject> sortedCities;
public static void main(String[] args) throws IOException { public static void main(String[] args) throws IOException {
SeekableByteRead(FILE); SeekableByteRead(FILE);
System.out.println(sortedCities); System.out.println(sortedCities.toString());
boolean DEBUG = true; boolean DEBUG = true;
} }
private static void SeekableByteRead(String path) throws IOException { private static void SeekableByteRead(String path) throws IOException {
FileInputStream fileInputStream = new FileInputStream(new File(FILE)); FileInputStream fileInputStream = new FileInputStream(new File(FILE));
FileChannel fileChannel = fileInputStream.getChannel(); FileChannel fileChannel = fileInputStream.getChannel();
Optional<Map<String, MeasurementObject>> optimistic = getFileSegments(new File(FILE), fileChannel) try {
.stream() sortedCities = getFileSegments(new File(FILE), fileChannel).stream()
.map(CalculateAverage_PanagiotisDrakatos::SplitSeekableByteChannel) .map(CalculateAverage_PanagiotisDrakatos::SplitSeekableByteChannel)
.parallel() .parallel()
.map(CalculateAverage_PanagiotisDrakatos::MappingByteBufferToData) .map(CalculateAverage_PanagiotisDrakatos::MappingByteBufferToData)
.reduce(CalculateAverage_PanagiotisDrakatos::combineMaps); .flatMap(MeasurementRepository::get)
.collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, MeasurementObject::updateWith, TreeMap::new));
}
catch (NullPointerException e) {
}
fileChannel.close(); fileChannel.close();
sortedCities = new TreeMap<>(optimistic.orElseThrow());
} }
record FileSegment(long start, long end, FileChannel fileChannel) { record FileSegment(long start, long end, FileChannel fileChannel) {
@ -95,14 +92,40 @@ public class CalculateAverage_PanagiotisDrakatos {
private static ByteBuffer SplitSeekableByteChannel(FileSegment segment) { private static ByteBuffer SplitSeekableByteChannel(FileSegment segment) {
try { try {
MappedByteBuffer buffer = segment.fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segment.end - segment.start()); MappedByteBuffer buffer = segment.fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segment.end - segment.start());
int end = buffer.limit() - 1; return buffer;
while (buffer.get(end) != '\n') {
end--;
}
return buffer.slice(0, end);
} }
catch (Exception ex) { catch (Exception ex) {
throw new RuntimeException(ex); long start = segment.start;
long end = 0;
try {
end = segment.fileChannel.size();
}
catch (IOException e) {
throw new RuntimeException(e);
}
MappedByteBuffer buffer = null;
ArrayList<ByteBuffer> list = new ArrayList<>();
while (start < end) {
try {
buffer = segment.fileChannel.map(FileChannel.MapMode.READ_ONLY, start, Math.min(MAP_SIZE, end - start));
// don't split the data in the middle of lines
// find the closest previous newline
int realEnd = buffer.limit() - 1;
while (buffer.get(realEnd) != '\n')
realEnd--;
realEnd++;
buffer.limit(realEnd);
start += realEnd;
list.add(buffer.slice(0, realEnd - 1));
}
catch (Exception e) {
e.printStackTrace();
}
}
sortedCities = list.stream().parallel().map(CalculateAverage_PanagiotisDrakatos::MappingByteBufferToData).flatMap(MeasurementRepository::get)
.collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, MeasurementObject::updateWith, TreeMap::new));
return null;
} }
} }
@ -121,38 +144,61 @@ public class CalculateAverage_PanagiotisDrakatos {
return all; return all;
} }
private static Map<String, MeasurementObject> combineMaps(Map<String, MeasurementObject> map1, Map<String, MeasurementObject> map2) { private static TreeMap<String, MeasurementObject> combineMaps(Stream<MeasurementRepository.Entry> stream1, Stream<MeasurementRepository.Entry> stream2) {
for (var entry : map2.entrySet()) { Stream<MeasurementRepository.Entry> resultingStream = Stream.concat(stream1, stream2);
map1.merge(entry.getKey(), entry.getValue(), MeasurementObject::combine); return resultingStream.collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, MeasurementObject::updateWith, TreeMap::new));
}
return map1;
} }
private static Map<String, MeasurementObject> MappingByteBufferToData(ByteBuffer byteBuffer) { private static int longHashStep(final int hash, final long word) {
Map<String, MeasurementObject> cities = new HashMap<>(); return 31 * hash + (int) (word ^ (word >>> 32));
}
private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');
private static long compilePattern(final byte value) {
return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) | ((long) value << 24) | ((long) value << 16)
| ((long) value << 8) | (long) value;
}
private static MeasurementRepository MappingByteBufferToData(ByteBuffer byteBuffer) {
MeasurementRepository measurements = new MeasurementRepository();
ByteBuffer bb = byteBuffer.duplicate(); ByteBuffer bb = byteBuffer.duplicate();
int start = 0; int start = 0;
int end = 0; int limit = bb.limit();
while (start < bb.limit()) {
while (bb.get(end) != ';') { long[] cityNameAsLongArray = new long[16];
end++; int[] delimiterPointerAndHash = new int[2];
bb.order(ByteOrder.nativeOrder());
final boolean bufferIsBigEndian = bb.order().equals(ByteOrder.BIG_ENDIAN);
while ((start = bb.position()) < limit + 1) {
int delimiterPointer;
findNextDelimiterAndCalculateHash(bb, SEPARATOR_PATTERN, start, limit, delimiterPointerAndHash, cityNameAsLongArray, bufferIsBigEndian);
delimiterPointer = delimiterPointerAndHash[0];
// Simple lookup is faster for '\n' (just three options)
if (delimiterPointer >= limit) {
return measurements;
} }
final int cityNameLength = delimiterPointer - start;
int temp_counter = 0; int temp_counter = 0;
int temp_end = end; int temp_end = delimiterPointer + 1;
try { try {
bb.position(end); // bb.position(delimiterPointer++);
while (bb.get(temp_end) != '\n') { while (bb.get(temp_end) != '\n') {
temp_counter++; temp_counter++;
temp_end++; temp_end++;
} }
} }
catch (IndexOutOfBoundsException e) { catch (IndexOutOfBoundsException e) {
temp_counter--; // temp_counter--;
temp_end--; // temp_end--;
} }
ByteBuffer city = bb.slice(start, end - start); ByteBuffer temp = bb.duplicate().slice(delimiterPointer + 1, temp_counter);
ByteBuffer temp = bb.slice(end + 1, temp_counter);
int tempPointer = 0; int tempPointer = 0;
int abs = 1; int abs = 1;
if (temp.get(0) == '-') { if (temp.get(0) == '-') {
@ -167,22 +213,141 @@ public class CalculateAverage_PanagiotisDrakatos {
measuredValue = abs * (temp.get(tempPointer) * 100 + temp.get(tempPointer + 1) * 10 + temp.get(tempPointer + 3) - 5328); measuredValue = abs * (temp.get(tempPointer) * 100 + temp.get(tempPointer + 1) * 10 + temp.get(tempPointer + 3) - 5328);
} }
byte[] citybytes = new byte[city.limit()]; measurements.update(cityNameAsLongArray, bb, cityNameLength, delimiterPointerAndHash[1]).updateWith(measuredValue);
city.get(citybytes);
String cityName = new String(citybytes, StandardCharsets.UTF_8);
// update the map with the new measurement if (temp_end + 1 > limit)
MeasurementObject agg = cities.get(cityName); return measurements;
if (agg == null) { bb.position(temp_end + 1);
cities.put(cityName, new MeasurementObject(measuredValue, measuredValue, 0, 0).updateWith(measuredValue));
}
else {
cities.put(cityName, agg.updateWith(measuredValue));
}
start = temp_end + 1;
end = temp_end;
} }
return cities; return measurements;
}
private static void findNextDelimiterAndCalculateHash(final ByteBuffer bb, final long pattern, final int start, final int limit, final int[] output,
final long[] asLong, final boolean bufferBigEndian) {
int hash = 1;
int i;
int lCnt = 0;
for (i = start; i <= limit - 8; i += 8) {
long word = bb.getLong(i);
if (bufferBigEndian) {
word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this
}
final long match = word ^ pattern;
long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L;
if (mask != 0) {
final int index = Long.numberOfTrailingZeros(mask) >> 3;
output[0] = (i + index);
final long partialHash = word & ((mask >> 7) - 1);
asLong[lCnt] = partialHash;
output[1] = longHashStep(hash, partialHash);
return;
}
asLong[lCnt++] = word;
hash = longHashStep(hash, word);
}
// Handle remaining bytes near the limit of the buffer:
long partialHash = 0;
int len = 0;
for (; i < limit; i++) {
byte read;
if ((read = bb.get(i)) == (byte) pattern) {
asLong[lCnt] = partialHash;
output[0] = i;
output[1] = longHashStep(hash, partialHash);
return;
}
partialHash = partialHash | ((long) read << (len << 3));
len++;
}
output[0] = limit; // delimiter not found
}
static class MeasurementRepository {
private int tableSize = 1 << 20; // can grow in theory, made large enough not to (this is faster)
private int tableMask = (tableSize - 1);
private int tableLimit = (int) (tableSize * LOAD_FACTOR);
private int tableFilled = 0;
private static final float LOAD_FACTOR = 0.8f;
private Entry[] table = new Entry[tableSize];
record Entry(int hash, long[] nameBytesInLong, String cityName, MeasurementObject measurement) {
@Override
public String toString() {
return cityName + "=" + measurement;
}
}
public MeasurementObject update(long[] nameBytesInLong, ByteBuffer bb, int length, int calculatedHash) {
final int nameBytesInLongLength = 1 + (length >>> 3);
int index = calculatedHash & tableMask;
Entry tableEntry;
while ((tableEntry = table[index]) != null
&& (tableEntry.hash != calculatedHash || !arrayEquals(tableEntry.nameBytesInLong, nameBytesInLong, nameBytesInLongLength))) { // search for the right spot
index = (index + 1) & tableMask;
}
if (tableEntry != null) {
return tableEntry.measurement;
}
// --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here.
MeasurementObject measurement = new MeasurementObject();
// Now create a string:
byte[] buffer = new byte[length];
bb.get(buffer, 0, length);
String cityName = new String(buffer, 0, length);
// Store the long[] for faster equals:
long[] nameBytesInLongCopy = new long[nameBytesInLongLength];
System.arraycopy(nameBytesInLong, 0, nameBytesInLongCopy, 0, nameBytesInLongLength);
// And add entry:
Entry toAdd = new Entry(calculatedHash, nameBytesInLongCopy, cityName, measurement);
table[index] = toAdd;
// Resize the table if filled too much:
if (++tableFilled > tableLimit) {
resizeTable();
}
return toAdd.measurement;
}
private void resizeTable() {
// Resize the table:
Entry[] oldEntries = table;
table = new Entry[tableSize <<= 2]; // x2
tableMask = (tableSize - 1);
tableLimit = (int) (tableSize * LOAD_FACTOR);
for (Entry entry : oldEntries) {
if (entry != null) {
int updatedTableIndex = entry.hash & tableMask;
while (table[updatedTableIndex] != null) {
updatedTableIndex = (updatedTableIndex + 1) & tableMask;
}
table[updatedTableIndex] = entry;
}
}
}
public Stream<Entry> get() {
return Arrays.stream(table).filter(Objects::nonNull);
}
}
private static boolean arrayEquals(final long[] a, final long[] b, final int length) {
for (int i = 0; i < length; i++) {
if (a[i] != b[i])
return false;
}
return true;
} }
private static final class MeasurementObject { private static final class MeasurementObject {
@ -202,6 +367,10 @@ public class CalculateAverage_PanagiotisDrakatos {
} }
public MeasurementObject() { public MeasurementObject() {
this.MAX = -999;
this.MIN = 9999;
this.SUM = 0;
this.REPEAT = 0;
} }
public MeasurementObject(int MAX, int MIN, long SUM) { public MeasurementObject(int MAX, int MIN, long SUM) {
@ -224,6 +393,15 @@ public class CalculateAverage_PanagiotisDrakatos {
return mres; return mres;
} }
public static MeasurementObject updateWith(MeasurementObject m1, MeasurementObject m2) {
var mres = new MeasurementObject();
mres.MIN = MeasurementObject.min(m1.MIN, m2.MIN);
mres.MAX = MeasurementObject.max(m1.MAX, m2.MAX);
mres.SUM = m1.SUM + m2.SUM;
mres.REPEAT = m1.REPEAT + m2.REPEAT;
return mres;
}
public MeasurementObject updateWith(int measurement) { public MeasurementObject updateWith(int measurement) {
MIN = MeasurementObject.min(MIN, measurement); MIN = MeasurementObject.min(MIN, measurement);
MAX = MeasurementObject.max(MAX, measurement); MAX = MeasurementObject.max(MAX, measurement);
@ -268,4 +446,4 @@ public class CalculateAverage_PanagiotisDrakatos {
return round(MIN) + "/" + round((1.0 * SUM) / REPEAT) + "/" + round(MAX); return round(MIN) + "/" + round((1.0 * SUM) / REPEAT) + "/" + round(MAX);
} }
} }
} }