improved 2nd and final submission (#685)
This commit is contained in:
parent
101993f06d
commit
75bece5364
@ -1,4 +1,4 @@
|
|||||||
#!/bin/sh
|
#!/bin/bash
|
||||||
#
|
#
|
||||||
# Copyright 2023 The original authors
|
# Copyright 2023 The original authors
|
||||||
#
|
#
|
||||||
@ -19,5 +19,8 @@
|
|||||||
# source "$HOME/.sdkman/bin/sdkman-init.sh"
|
# source "$HOME/.sdkman/bin/sdkman-init.sh"
|
||||||
# sdk use java 21.0.1-graal 1>&2
|
# sdk use java 21.0.1-graal 1>&2
|
||||||
|
|
||||||
JAVA_OPTS="--enable-preview --enable-native-access=ALL-UNNAMED --add-modules jdk.incubator.vector"
|
JAVA_OPTS="-Xlog:all=off -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 --enable-preview --enable-native-access=ALL-UNNAMED --add-modules jdk.incubator.vector"
|
||||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yourwass
|
|
||||||
|
eval "exec 3< <({ java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yourwass; })"
|
||||||
|
read <&3 result
|
||||||
|
echo -e "$result"
|
||||||
|
@ -16,6 +16,8 @@
|
|||||||
package dev.morling.onebrc;
|
package dev.morling.onebrc;
|
||||||
|
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
|
import java.util.concurrent.locks.Lock;
|
||||||
|
import java.util.concurrent.locks.ReentrantLock;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.lang.foreign.Arena;
|
import java.lang.foreign.Arena;
|
||||||
import java.lang.foreign.MemorySegment;
|
import java.lang.foreign.MemorySegment;
|
||||||
@ -31,18 +33,15 @@ import jdk.incubator.vector.VectorSpecies;
|
|||||||
import sun.misc.Unsafe;
|
import sun.misc.Unsafe;
|
||||||
|
|
||||||
public class CalculateAverage_yourwass {
|
public class CalculateAverage_yourwass {
|
||||||
|
|
||||||
static final class Record {
|
static final class Record {
|
||||||
public String city;
|
private long cityAddr;
|
||||||
public long cityAddr;
|
private long cityLength;
|
||||||
public long cityLength;
|
private int min;
|
||||||
public int min;
|
private int max;
|
||||||
public int max;
|
private int count;
|
||||||
public int count;
|
private long sum;
|
||||||
public long sum;
|
|
||||||
|
|
||||||
Record(final long cityAddr, final long cityLength) {
|
Record(final long cityAddr, final long cityLength) {
|
||||||
this.city = null;
|
|
||||||
this.cityAddr = cityAddr;
|
this.cityAddr = cityAddr;
|
||||||
this.cityLength = cityLength;
|
this.cityLength = cityLength;
|
||||||
this.min = 1000;
|
this.min = 1000;
|
||||||
@ -62,6 +61,8 @@ public class CalculateAverage_yourwass {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private final static Lock _mutex = new ReentrantLock(true);
|
||||||
|
private final static TreeMap<String, Record> aggregateResults = new TreeMap<>();
|
||||||
private static short lookupDecimal[];
|
private static short lookupDecimal[];
|
||||||
private static byte lookupFraction[];
|
private static byte lookupFraction[];
|
||||||
private static byte lookupDotPositive[];
|
private static byte lookupDotPositive[];
|
||||||
@ -70,6 +71,8 @@ public class CalculateAverage_yourwass {
|
|||||||
private static final VectorSpecies<Byte> SPECIES = ByteVector.SPECIES_PREFERRED;
|
private static final VectorSpecies<Byte> SPECIES = ByteVector.SPECIES_PREFERRED;
|
||||||
private static final int MAXINDEX = (1 << 16) + 10000; // short hash + max allowed cities for collisions at the end :p
|
private static final int MAXINDEX = (1 << 16) + 10000; // short hash + max allowed cities for collisions at the end :p
|
||||||
private static final String FILE = "measurements.txt";
|
private static final String FILE = "measurements.txt";
|
||||||
|
private static long unsafeResults;
|
||||||
|
private static int RECORDSIZE = 36;
|
||||||
private static final Unsafe UNSAFE = getUnsafe();
|
private static final Unsafe UNSAFE = getUnsafe();
|
||||||
|
|
||||||
private static Unsafe getUnsafe() {
|
private static Unsafe getUnsafe() {
|
||||||
@ -113,11 +116,9 @@ public class CalculateAverage_yourwass {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// open file
|
// open file
|
||||||
final long fileSize, mmapAddr;
|
final FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
|
||||||
try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
|
final long fileSize = fileChannel.size();
|
||||||
fileSize = fileChannel.size();
|
final long mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
|
||||||
mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
|
|
||||||
}
|
|
||||||
// VAS: Virtual Address Space, as a MemorySegment upto and including the mmaped file.
|
// VAS: Virtual Address Space, as a MemorySegment upto and including the mmaped file.
|
||||||
// If the mmaped MemorySegment is used for Vector creation as is, then there are two problems:
|
// If the mmaped MemorySegment is used for Vector creation as is, then there are two problems:
|
||||||
// 1) fromMemorySegment takes an offset and not an address, so we have to do arithmetic
|
// 1) fromMemorySegment takes an offset and not an address, so we have to do arithmetic
|
||||||
@ -127,36 +128,24 @@ public class CalculateAverage_yourwass {
|
|||||||
// XXX there lies the possibility for an out of bounds read at the end of file, which is not handled here.
|
// XXX there lies the possibility for an out of bounds read at the end of file, which is not handled here.
|
||||||
VAS = MemorySegment.ofAddress(0).reinterpret(mmapAddr + fileSize + SPECIES.length());
|
VAS = MemorySegment.ofAddress(0).reinterpret(mmapAddr + fileSize + SPECIES.length());
|
||||||
|
|
||||||
// start and wait for threads to finish
|
// allocate memory for results
|
||||||
final int nThreads = Runtime.getRuntime().availableProcessors();
|
final int nThreads = Runtime.getRuntime().availableProcessors();
|
||||||
|
unsafeResults = UNSAFE.allocateMemory(RECORDSIZE * MAXINDEX * nThreads);
|
||||||
|
UNSAFE.setMemory(unsafeResults, RECORDSIZE * MAXINDEX * nThreads, (byte) 0);
|
||||||
|
|
||||||
|
// start and wait for threads to finish
|
||||||
Thread[] threadList = new Thread[nThreads];
|
Thread[] threadList = new Thread[nThreads];
|
||||||
final Record[][] results = new Record[nThreads][];
|
|
||||||
final long chunkSize = fileSize / nThreads;
|
final long chunkSize = fileSize / nThreads;
|
||||||
for (int i = 0; i < nThreads; i++) {
|
for (int i = 0; i < nThreads; i++) {
|
||||||
final int threadIndex = i;
|
final int threadIndex = i;
|
||||||
final long startAddr = mmapAddr + i * chunkSize;
|
final long startAddr = mmapAddr + i * chunkSize;
|
||||||
final long endAddr = (i == nThreads - 1) ? mmapAddr + fileSize : mmapAddr + (i + 1) * chunkSize;
|
final long endAddr = (i == nThreads - 1) ? mmapAddr + fileSize : mmapAddr + (i + 1) * chunkSize;
|
||||||
threadList[i] = new Thread(() -> results[threadIndex] = threadMain(threadIndex, startAddr, endAddr, nThreads));
|
threadList[i] = new Thread(() -> threadMain(threadIndex, startAddr, endAddr, nThreads));
|
||||||
threadList[i].start();
|
threadList[i].start();
|
||||||
}
|
}
|
||||||
for (int i = 0; i < nThreads; i++)
|
for (int i = 0; i < nThreads; i++)
|
||||||
threadList[i].join();
|
threadList[i].join();
|
||||||
|
|
||||||
// aggregate results and sort
|
|
||||||
// TODO have to compare with concurrent-parallel stream structures:
|
|
||||||
// * concurrent hashtable that have to sort afterwards
|
|
||||||
// * concurrent skiplist that is sorted but has O(n) insert
|
|
||||||
// * ..other?
|
|
||||||
final TreeMap<String, Record> aggregateResults = new TreeMap<>();
|
|
||||||
for (int thread = 0; thread < nThreads; thread++) {
|
|
||||||
for (int index = 0; index < MAXINDEX; index++) {
|
|
||||||
Record record = results[thread][index];
|
|
||||||
if (record == null)
|
|
||||||
continue;
|
|
||||||
aggregateResults.compute(record.city, (k, v) -> (v == null) ? record : v.merge(record));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepare string and print
|
// prepare string and print
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
sb.append("{");
|
sb.append("{");
|
||||||
@ -167,12 +156,13 @@ public class CalculateAverage_yourwass {
|
|||||||
float max = record.max;
|
float max = record.max;
|
||||||
max /= 10.f;
|
max /= 10.f;
|
||||||
double avg = Math.round((record.sum * 1.0) / record.count) / 10.;
|
double avg = Math.round((record.sum * 1.0) / record.count) / 10.;
|
||||||
sb.append(record.city).append("=").append(min).append("/").append(avg).append("/").append(max).append(", ");
|
sb.append(entry.getKey()).append("=").append(min).append("/").append(avg).append("/").append(max).append(", ");
|
||||||
}
|
}
|
||||||
int stringLength = sb.length();
|
int stringLength = sb.length();
|
||||||
sb.setCharAt(stringLength - 2, '}');
|
sb.setCharAt(stringLength - 2, '}');
|
||||||
sb.setCharAt(stringLength - 1, '\n');
|
sb.setCharAt(stringLength - 1, '\n');
|
||||||
System.out.print(sb.toString());
|
System.out.print(sb.toString());
|
||||||
|
System.out.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final boolean citiesDiffer(final long a, final long b, final long len) {
|
private static final boolean citiesDiffer(final long a, final long b, final long len) {
|
||||||
@ -185,7 +175,7 @@ public class CalculateAverage_yourwass {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Record[] threadMain(int id, long startAddr, long endAddr, long nThreads) {
|
private static void threadMain(int id, long startAddr, long endAddr, long nThreads) {
|
||||||
// snap to newlines
|
// snap to newlines
|
||||||
if (id != 0)
|
if (id != 0)
|
||||||
while (UNSAFE.getByte(startAddr++) != '\n')
|
while (UNSAFE.getByte(startAddr++) != '\n')
|
||||||
@ -194,23 +184,24 @@ public class CalculateAverage_yourwass {
|
|||||||
while (UNSAFE.getByte(endAddr++) != '\n')
|
while (UNSAFE.getByte(endAddr++) != '\n')
|
||||||
;
|
;
|
||||||
|
|
||||||
|
final long threadResults = unsafeResults + id * MAXINDEX * RECORDSIZE;
|
||||||
final Record[] results = new Record[MAXINDEX];
|
final Record[] results = new Record[MAXINDEX];
|
||||||
final long VECTORBYTESIZE = SPECIES.length();
|
final long VECTORBYTESIZE = SPECIES.length();
|
||||||
final ByteOrder BYTEORDER = ByteOrder.nativeOrder();
|
final ByteOrder BYTEORDER = ByteOrder.nativeOrder();
|
||||||
final ByteVector delim = ByteVector.broadcast(SPECIES, ';');
|
final ByteVector delim = ByteVector.broadcast(SPECIES, ';');
|
||||||
long nextCityAddr = startAddr; // XXX from these three variables,
|
long cityAddr = startAddr;
|
||||||
long cityAddr = nextCityAddr; // only two are necessary, but if one
|
long ptr = 0;
|
||||||
long ptr = 0; // is eliminated, on my pc the benchmark gets worse..
|
while (cityAddr < endAddr) {
|
||||||
while (nextCityAddr < endAddr) {
|
|
||||||
// parse city
|
// parse city
|
||||||
long mask = ByteVector.fromMemorySegment(SPECIES, VAS, nextCityAddr + ptr, BYTEORDER)
|
ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER);
|
||||||
.compare(VectorOperators.EQ, delim).toLong();
|
long mask = parsed.compare(VectorOperators.EQ, delim).toLong();
|
||||||
if (mask == 0) {
|
while (mask == 0) {
|
||||||
ptr += VECTORBYTESIZE;
|
ptr += VECTORBYTESIZE;
|
||||||
continue;
|
mask = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr + ptr, BYTEORDER).compare(VectorOperators.EQ, delim).toLong();
|
||||||
}
|
}
|
||||||
final long cityLength = ptr + Long.numberOfTrailingZeros(mask);
|
final long cityLength = ptr + Long.numberOfTrailingZeros(mask);
|
||||||
final long tempAddr = cityAddr + cityLength + 1;
|
final long tempAddr = cityAddr + cityLength + 1;
|
||||||
|
ptr = 0;
|
||||||
|
|
||||||
// compute hash table index
|
// compute hash table index
|
||||||
int index;
|
int index;
|
||||||
@ -222,67 +213,79 @@ public class CalculateAverage_yourwass {
|
|||||||
& 0xFFFF;
|
& 0xFFFF;
|
||||||
else
|
else
|
||||||
index = (UNSAFE.getByte(cityAddr) << 8) & 0xFF00;
|
index = (UNSAFE.getByte(cityAddr) << 8) & 0xFF00;
|
||||||
|
|
||||||
// resolve collisions with linear probing
|
// resolve collisions with linear probing
|
||||||
// use vector api here also, but only if city name fits in one vector length, for faster default case
|
// use vector api here also, but only if city name fits in one vector length, for faster default case
|
||||||
Record record = results[index];
|
long record = threadResults + index * RECORDSIZE;
|
||||||
|
long recordCityLength = UNSAFE.getLong(record);
|
||||||
if (cityLength <= VECTORBYTESIZE) {
|
if (cityLength <= VECTORBYTESIZE) {
|
||||||
ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER);
|
while (recordCityLength > 0) {
|
||||||
while (record != null) {
|
if (cityLength == recordCityLength) {
|
||||||
if (cityLength == record.cityLength) {
|
long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, UNSAFE.getLong(record + 8), BYTEORDER)
|
||||||
long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, record.cityAddr, BYTEORDER)
|
|
||||||
.compare(VectorOperators.EQ, parsed).toLong();
|
.compare(VectorOperators.EQ, parsed).toLong();
|
||||||
if (Long.numberOfTrailingZeros(~sameMask) >= cityLength)
|
if (Long.numberOfTrailingZeros(~sameMask) >= cityLength)
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
record = results[++index];
|
index++;
|
||||||
|
record = threadResults + index * RECORDSIZE;
|
||||||
|
recordCityLength = UNSAFE.getLong(record);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else { // slower normal case for city names with length > VECTORBYTESIZE
|
else { // slower normal case for city names with length > VECTORBYTESIZE
|
||||||
while (record != null && (cityLength != record.cityLength || citiesDiffer(record.cityAddr, cityAddr, cityLength)))
|
while (recordCityLength > 0 && (cityLength != recordCityLength || citiesDiffer(UNSAFE.getLong(record + 8), cityAddr, cityLength))) {
|
||||||
record = results[++index];
|
index++;
|
||||||
|
record = threadResults + index * RECORDSIZE;
|
||||||
|
recordCityLength = UNSAFE.getLong(record);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// add record for new keys
|
// add record for new key
|
||||||
// TODO have to avoid memory allocations on hot path
|
if (recordCityLength == 0) {
|
||||||
if (record == null) {
|
UNSAFE.putLong(record, cityLength);
|
||||||
results[index] = new Record(cityAddr, cityLength);
|
UNSAFE.putLong(record + 8, cityAddr);
|
||||||
record = results[index];
|
UNSAFE.putInt(record + 16, 1000);
|
||||||
|
UNSAFE.putInt(record + 20, -1000);
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse temp with lookup tables
|
// parse temp with lookup tables
|
||||||
int temp;
|
int temp;
|
||||||
if (UNSAFE.getByte(tempAddr) == '-') {
|
if (UNSAFE.getByte(tempAddr) == '-') {
|
||||||
temp = -lookupDecimal[UNSAFE.getShort(tempAddr + 1)] - lookupFraction[UNSAFE.getShort(tempAddr + 3)];
|
temp = -lookupDecimal[UNSAFE.getShort(tempAddr + 1)] - lookupFraction[UNSAFE.getShort(tempAddr + 3)];
|
||||||
nextCityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)];
|
cityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)];
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
temp = lookupDecimal[UNSAFE.getShort(tempAddr)] + lookupFraction[UNSAFE.getShort(tempAddr + 2)];
|
temp = lookupDecimal[UNSAFE.getShort(tempAddr)] + lookupFraction[UNSAFE.getShort(tempAddr + 2)];
|
||||||
nextCityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)];
|
cityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)];
|
||||||
}
|
}
|
||||||
cityAddr = nextCityAddr;
|
|
||||||
ptr = 0;
|
|
||||||
|
|
||||||
// merge record
|
// merge
|
||||||
if (temp < record.min)
|
if (temp < UNSAFE.getInt(record + 16))
|
||||||
record.min = temp;
|
UNSAFE.putInt(record + 16, temp);
|
||||||
if (temp > record.max)
|
if (temp > UNSAFE.getInt(record + 20))
|
||||||
record.max = temp;
|
UNSAFE.putInt(record + 20, temp);
|
||||||
record.sum += temp;
|
UNSAFE.putLong(record + 24, UNSAFE.getLong(record + 24) + temp);
|
||||||
record.count += 1;
|
UNSAFE.putInt(record + 32, UNSAFE.getInt(record + 32) + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// create strings from raw data
|
// create strings from raw data
|
||||||
// TODO should avoid this copy
|
// and aggregate results onto TreeMap
|
||||||
|
int idx = 0;
|
||||||
byte b[] = new byte[100];
|
byte b[] = new byte[100];
|
||||||
|
_mutex.lock();
|
||||||
for (int i = 0; i < MAXINDEX; i++) {
|
for (int i = 0; i < MAXINDEX; i++) {
|
||||||
Record r = results[i];
|
if (UNSAFE.getLong(threadResults + i * RECORDSIZE) == 0)
|
||||||
if (r == null)
|
|
||||||
continue;
|
continue;
|
||||||
UNSAFE.copyMemory(null, r.cityAddr, b, Unsafe.ARRAY_BYTE_BASE_OFFSET, r.cityLength);
|
final long recordAddress = threadResults + i * RECORDSIZE;
|
||||||
r.city = new String(b, 0, (int) r.cityLength, StandardCharsets.UTF_8);
|
|
||||||
}
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
results[idx] = new Record(UNSAFE.getLong(recordAddress + 8), UNSAFE.getLong(recordAddress));
|
||||||
|
results[idx].min = UNSAFE.getInt(recordAddress + 16);
|
||||||
|
results[idx].max = UNSAFE.getInt(recordAddress + 20);
|
||||||
|
results[idx].sum = UNSAFE.getLong(recordAddress + 24);
|
||||||
|
results[idx].count = UNSAFE.getInt(recordAddress + 32);
|
||||||
|
UNSAFE.copyMemory(null, UNSAFE.getLong(recordAddress + 8), b, Unsafe.ARRAY_BYTE_BASE_OFFSET, UNSAFE.getLong(recordAddress));
|
||||||
|
final Record record = results[idx];
|
||||||
|
aggregateResults.compute(new String(b, 0, (int) results[idx].cityLength, StandardCharsets.UTF_8), (k, v) -> (v == null) ? record : v.merge(record));
|
||||||
|
idx++;
|
||||||
|
}
|
||||||
|
_mutex.unlock();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user