Version 3 (#455)

This commit is contained in:
Roman Musin 2024-01-17 17:07:56 +00:00 committed by GitHub
parent 1bbddaaaf6
commit 77872e197d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -24,17 +24,12 @@ import java.lang.foreign.ValueLayout;
import java.lang.reflect.Field;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.IntStream;
public class CalculateAverage_roman_r_m {
public static final int DOT_3_RD_BYTE_MASK = (byte) '.' << 16;
private static final String FILE = "./measurements.txt";
private static MemorySegment ms;
private static Unsafe UNSAFE;
@ -60,7 +55,7 @@ public class CalculateAverage_roman_r_m {
return match != 0 ? firstSetByteIndex(match) : -1;
}
static long nextNewline(long from) {
static long nextNewline(long from, MemorySegment ms) {
long start = from;
long i;
long next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, start);
@ -71,6 +66,110 @@ public class CalculateAverage_roman_r_m {
return start + i;
}
/**
 * Parses one chunk of the input and aggregates min/max/sum/count per station.
 *
 * The worker keeps its own slice of the mapped segment; {@code offset} and
 * {@code end} are relative to the start of that slice. All reads go through
 * {@code UNSAFE} using {@code ms.address() + offset}, so no bounds checks are
 * performed by the memory segment API.
 *
 * Record layout, as read by the parsers below: station name, ';', optional '-',
 * one or more integer digits, '.', one fractional digit, then one terminator byte.
 */
static class Worker {
// Slice of the mapped segment covering [start, end) of the caller's range.
private final MemorySegment ms;
// Length of the slice, i.e. the exclusive upper bound for offset.
private final long end;
// Current read position, relative to the start of the slice.
private long offset;
/**
 * @param ms    segment that contains the chunk
 * @param start position of the first byte of the chunk within {@code ms}
 * @param end   position just past the chunk within {@code ms}
 */
public Worker(MemorySegment ms, long start, long end) {
this.ms = ms.asSlice(start, end - start);
this.offset = 0;
this.end = end - start;
}
/**
 * Advances {@code offset} past the station name and its ';' separator and
 * records the name's position and length in {@code station}.
 *
 * @param station reusable holder; its offset (slice-relative), len and hash are overwritten
 */
private void parseName(ByteString station) {
long start = offset;
long pos = -1;
// Word-at-a-time scan: only used while more than 8 bytes remain in the slice.
while (end - offset > 8) {
long next = UNSAFE.getLong(ms.address() + offset);
// NOTE(review): find is defined outside this view; it is used here as
// "index of the matching byte within the word, negative when absent".
pos = find(next, SEMICOLON_MASK);
if (pos >= 0) {
offset += pos;
break;
}
else {
offset += 8;
}
}
// Tail of the chunk (or no match above): fall back to a byte-by-byte scan.
// This loop has no bound of its own; it relies on a ';' being present.
if (pos < 0) {
while (UNSAFE.getByte(ms.address() + offset++) != ';') {
}
// The post-increment stepped past the ';'; move back onto it.
offset--;
}
int len = (int) (offset - start);
station.offset = start;
station.len = len;
// 0 makes ByteString.hashCode() recompute the hash for the new name.
station.hash = 0;
// Skip the ';'.
offset++;
}
/**
 * Parses an unsigned reading from a single 8-byte load and advances
 * {@code offset} past it and the following terminator byte.
 *
 * @return the reading as an integer number of tenths
 */
long parseNumberFast() {
long encodedVal = UNSAFE.getLong(ms.address() + offset);
// NOTE(review): presumably the index of the line terminator within the word,
// i.e. 3 for "d.d" and 4 for "dd.d" - confirm against find/LINE_END_MASK.
var len = find(encodedVal, LINE_END_MASK);
offset += len + 1;
// NOTE(review): broadcast is defined outside this view; presumably it repeats
// the byte in every lane so the XOR maps ASCII digits to their numeric value.
encodedVal ^= broadcast((byte) 0x30);
// Branch-free weights for bytes 0..3 of the word (lowest byte first):
//   len == 3: c0=10,  c1=0,  c2=1, c3=0  -> b0*10 + b2
//   len == 4: c0=100, c1=10, c2=0, c3=1  -> b0*100 + b1*10 + b3
// The byte holding '.' always gets weight 0.
// NOTE(review): treating the lowest byte as the first character assumes a
// little-endian load - confirm for the target platform.
long c0 = len == 4 ? 100 : 10;
long c1 = 10 * (len - 3);
long c2 = 4 - len;
long c3 = len - 3;
long a = (encodedVal & 0xFF) * c0;
long b = ((encodedVal & 0xFF00) >>> 8) * c1;
long c = ((encodedVal & 0xFF0000L) >>> 16) * c2;
long d = ((encodedVal & 0xFF000000L) >>> 24) * c3;
return a + b + c + d;
}
/**
 * Parses an unsigned reading one byte at a time; used when 8 or fewer bytes
 * remain in the slice (see parseNumber).
 *
 * @return the reading as an integer number of tenths
 */
long parseNumberSlow() {
long val = UNSAFE.getByte(ms.address() + offset++) - '0';
byte b;
// Accumulate integer digits up to the decimal point.
while ((b = UNSAFE.getByte(ms.address() + offset++)) != '.') {
val = val * 10 + (b - '0');
}
// Single fractional digit.
b = UNSAFE.getByte(ms.address() + offset);
val = val * 10 + (b - '0');
// Skip the fractional digit and the byte that follows it.
offset += 2;
return val;
}
/**
 * Parses the signed reading at {@code offset} and advances past it.
 *
 * @return the reading as an integer number of tenths, negative if prefixed with '-'
 */
long parseNumber() {
long val;
// Bit 0x10 is set in ASCII digits (0x30-0x39) and clear in '-' (0x2D),
// so neg is 1 for a leading minus sign and 0 for a digit.
int neg = 1 - Integer.bitCount(UNSAFE.getByte(ms.address() + offset) & 0x10);
// Skip the sign when present.
offset += neg;
if (end - offset > 8) {
val = parseNumberFast();
}
else {
val = parseNumberSlow();
}
// Multiply by -1 when neg == 1, by 1 otherwise.
val *= 1 - 2 * neg;
return val;
}
/**
 * Processes every record in the chunk and returns the per-station aggregates.
 *
 * @return map produced by {@code ResultStore.toMap()} for this chunk
 */
public TreeMap<String, ResultRow> run() {
var resultStore = new ResultStore();
// One mutable ByteString is reused for every record.
var station = new ByteString(ms);
while (offset < end) {
parseName(station);
long val = parseNumber();
var a = resultStore.get(station);
a.min = Math.min(a.min, val);
a.max = Math.max(a.max, val);
a.sum += val;
a.count++;
}
return resultStore.toMap();
}
}
public static void main(String[] args) throws Exception {
Field f = Unsafe.class.getDeclaredField("theUnsafe");
f.setAccessible(true);
@ -79,98 +178,18 @@ public class CalculateAverage_roman_r_m {
long fileSize = new File(FILE).length();
var channel = FileChannel.open(Paths.get(FILE));
ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofAuto());
MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofAuto());
int numThreads = fileSize > Integer.MAX_VALUE ? Runtime.getRuntime().availableProcessors() : 1;
long chunk = fileSize / numThreads;
var result = IntStream.range(0, numThreads)
.parallel()
.mapToObj(i -> {
boolean lastChunk = i == numThreads - 1;
long chunkStart = i == 0 ? 0 : nextNewline(i * chunk) + 1;
long chunkEnd = lastChunk ? fileSize : nextNewline((i + 1) * chunk);
var resultStore = new ResultStore();
var station = new ByteString();
long offset = chunkStart;
while (offset < chunkEnd) {
long start = offset;
long pos = -1;
while (chunkEnd - offset >= 8) {
long next = UNSAFE.getLong(ms.address() + offset);
pos = find(next, SEMICOLON_MASK);
if (pos >= 0) {
offset += pos;
break;
}
else {
offset += 8;
}
}
if (pos < 0) {
while (UNSAFE.getByte(ms.address() + offset++) != ';') {
}
offset--;
}
int len = (int) (offset - start);
// TODO can we not copy and use a reference into the memory segment to perform table lookup?
station.offset = start;
station.len = len;
station.hash = 0;
offset++;
long val;
boolean neg;
if (!lastChunk || fileSize - offset >= 8) {
long encodedVal = UNSAFE.getLong(ms.address() + offset);
neg = (encodedVal & (byte) '-') == (byte) '-';
if (neg) {
encodedVal >>= 8;
offset++;
}
if ((encodedVal & DOT_3_RD_BYTE_MASK) == DOT_3_RD_BYTE_MASK) {
val = (encodedVal & 0xFF - 0x30) * 100 + (encodedVal >> 8 & 0xFF - 0x30) * 10 + (encodedVal >> 24 & 0xFF - 0x30);
offset += 5;
}
else {
// based on http://0x80.pl/articles/simd-parsing-int-sequences.html#parsing-and-conversion-of-signed-numbers
val = Long.compress(encodedVal, 0xFF00FFL) - 0x303030;
val = ((val * 2561) >> 8) & 0xff;
offset += 4;
}
}
else {
neg = UNSAFE.getByte(ms.address() + offset) == '-';
if (neg) {
offset++;
}
val = UNSAFE.getByte(ms.address() + offset++) - '0';
byte b;
while ((b = UNSAFE.getByte(ms.address() + offset++)) != '.') {
val = val * 10 + (b - '0');
}
b = UNSAFE.getByte(ms.address() + offset);
val = val * 10 + (b - '0');
offset += 2;
}
if (neg) {
val = -val;
}
var a = resultStore.get(station);
a.min = Math.min(a.min, val);
a.max = Math.max(a.max, val);
a.sum += val;
a.count++;
}
return resultStore.toMap();
long chunkStart = i == 0 ? 0 : nextNewline(i * chunk, ms) + 1;
long chunkEnd = lastChunk ? fileSize : nextNewline((i + 1) * chunk, ms);
return new Worker(ms, chunkStart, chunkEnd).run();
}).reduce((m1, m2) -> {
m2.forEach((k, v) -> m1.merge(k, v, ResultRow::merge));
return m1;
@ -181,19 +200,24 @@ public class CalculateAverage_roman_r_m {
static final class ByteString {
private final MemorySegment ms;
private long offset;
private int len = 0;
private int hash = 0;
ByteString(MemorySegment ms) {
this.ms = ms;
}
@Override
public String toString() {
var bytes = new byte[len];
MemorySegment.copy(ms, ValueLayout.JAVA_BYTE, offset, bytes, 0, len);
UNSAFE.copyMemory(null, ms.address() + offset, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, len);
return new String(bytes, 0, len);
}
public ByteString copy() {
var copy = new ByteString();
var copy = new ByteString(ms);
copy.offset = this.offset;
copy.len = this.len;
copy.hash = this.hash;
@ -216,13 +240,18 @@ public class CalculateAverage_roman_r_m {
long base1 = ms.address() + offset;
long base2 = ms.address() + that.offset;
for (; i + 3 < len; i += 4) {
int i1 = UNSAFE.getInt(base1 + i);
int i2 = UNSAFE.getInt(base2 + i);
if (i1 != i2) {
for (; i + 7 < len; i += 8) {
long l1 = UNSAFE.getLong(base1 + i);
long l2 = UNSAFE.getLong(base2 + i);
if (l1 != l2) {
return false;
}
}
if (len >= 8) {
long l1 = UNSAFE.getLong(base1 + len - 8);
long l2 = UNSAFE.getLong(base2 + len - 8);
return l1 == l2;
}
for (; i < len; i++) {
byte i1 = UNSAFE.getByte(base1 + i);
byte i2 = UNSAFE.getByte(base2 + i);
@ -236,10 +265,9 @@ public class CalculateAverage_roman_r_m {
@Override
public int hashCode() {
if (hash == 0) {
// not sure why but it seems to be working a bit better
hash = UNSAFE.getInt(ms.address() + offset);
hash = hash >>> (8 * Math.max(0, 4 - len));
hash |= len;
long h = UNSAFE.getLong(ms.address() + offset);
h = Long.reverseBytes(h) >>> (8 * Math.max(0, 8 - len));
hash = (int) (h ^ (h >>> 32));
}
return hash;
}
@ -269,25 +297,40 @@ public class CalculateAverage_roman_r_m {
}
static class ResultStore {
private final ArrayList<ResultRow> results = new ArrayList<>(10000);
private final Map<ByteString, Integer> indices = new HashMap<>(10000);
private static final int SIZE = 16384;
private final ByteString[] keys = new ByteString[SIZE];
private final ResultRow[] values = new ResultRow[SIZE];
ResultRow get(ByteString s) {
var idx = indices.get(s);
if (idx != null) {
return results.get(idx);
int h = s.hashCode();
int idx = (SIZE - 1) & h;
int i = 0;
while (keys[idx] != null && !keys[idx].equals(s)) {
i++;
idx = (idx + i * i) % SIZE;
}
ResultRow result;
if (keys[idx] == null) {
keys[idx] = s.copy();
result = new ResultRow();
values[idx] = result;
}
else {
ResultRow next = new ResultRow();
results.add(next);
indices.put(s.copy(), results.size() - 1);
return next;
result = values[idx];
// TODO see if it makes any difference
// keys[idx].offset = s.offset;
}
return result;
}
TreeMap<String, ResultRow> toMap() {
var result = new TreeMap<String, ResultRow>();
indices.forEach((name, idx) -> result.put(name.toString(), results.get(idx)));
for (int i = 0; i < SIZE; i++) {
if (keys[i] != null) {
result.put(keys[i].toString(), values[i]);
}
}
return result;
}
}