Version 4 - roman-r-m (#484)

* Version 3

* Try to optimize memory access (-0.2s)

- use smaller memory segments, each confined to its worker thread
- unload each segment in parallel

* Only call MemorySegment.address() once (~200ms)
This commit is contained in:
Roman Musin 2024-01-19 16:20:57 +00:00 committed by GitHub
parent fefe326a14
commit ec27a47ce1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -71,10 +71,15 @@ public class CalculateAverage_roman_r_m {
private final long end;
private long offset;
public Worker(MemorySegment ms, long start, long end) {
this.ms = ms.asSlice(start, end - start);
this.offset = 0;
this.end = end - start;
public Worker(FileChannel channel, long start, long end) {
try {
this.ms = channel.map(FileChannel.MapMode.READ_ONLY, start, end - start, Arena.ofConfined());
this.offset = ms.address();
this.end = ms.address() + end - start;
}
catch (Exception e) {
throw new RuntimeException(e);
}
}
private void parseName(ByteString station) {
@ -82,7 +87,7 @@ public class CalculateAverage_roman_r_m {
long pos = -1;
while (end - offset > 8) {
long next = UNSAFE.getLong(ms.address() + offset);
long next = UNSAFE.getLong(offset);
pos = find(next, SEMICOLON_MASK);
if (pos >= 0) {
offset += pos;
@ -93,7 +98,7 @@ public class CalculateAverage_roman_r_m {
}
}
if (pos < 0) {
while (UNSAFE.getByte(ms.address() + offset++) != ';') {
while (UNSAFE.getByte(offset++) != ';') {
}
offset--;
}
@ -107,7 +112,7 @@ public class CalculateAverage_roman_r_m {
}
long parseNumberFast() {
long encodedVal = UNSAFE.getLong(ms.address() + offset);
long encodedVal = UNSAFE.getLong(offset);
var len = find(encodedVal, LINE_END_MASK);
offset += len + 1;
@ -127,12 +132,12 @@ public class CalculateAverage_roman_r_m {
}
long parseNumberSlow() {
long val = UNSAFE.getByte(ms.address() + offset++) - '0';
long val = UNSAFE.getByte(offset++) - '0';
byte b;
while ((b = UNSAFE.getByte(ms.address() + offset++)) != '.') {
while ((b = UNSAFE.getByte(offset++)) != '.') {
val = val * 10 + (b - '0');
}
b = UNSAFE.getByte(ms.address() + offset);
b = UNSAFE.getByte(offset);
val = val * 10 + (b - '0');
offset += 2;
return val;
@ -140,7 +145,7 @@ public class CalculateAverage_roman_r_m {
long parseNumber() {
long val;
int neg = 1 - Integer.bitCount(UNSAFE.getByte(ms.address() + offset) & 0x10);
int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10);
offset += neg;
if (end - offset > 8) {
@ -178,18 +183,27 @@ public class CalculateAverage_roman_r_m {
long fileSize = new File(FILE).length();
var channel = FileChannel.open(Paths.get(FILE));
MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofAuto());
MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofConfined());
int numThreads = fileSize > Integer.MAX_VALUE ? Runtime.getRuntime().availableProcessors() : 1;
long chunk = fileSize / numThreads;
var bounds = IntStream.range(0, numThreads).mapToLong(i -> {
boolean lastChunk = i == numThreads - 1;
return lastChunk ? fileSize : nextNewline((i + 1) * chunk, ms);
}).toArray();
ms.unload();
var result = IntStream.range(0, numThreads)
.parallel()
.mapToObj(i -> {
boolean lastChunk = i == numThreads - 1;
long chunkStart = i == 0 ? 0 : nextNewline(i * chunk, ms) + 1;
long chunkEnd = lastChunk ? fileSize : nextNewline((i + 1) * chunk, ms);
return new Worker(ms, chunkStart, chunkEnd).run();
long start = i == 0 ? 0 : bounds[i - 1] + 1;
long end = bounds[i];
Worker worker = new Worker(channel, start, end);
var res = worker.run();
worker.ms.unload();
return res;
}).reduce((m1, m2) -> {
m2.forEach((k, v) -> m1.merge(k, v, ResultRow::merge));
return m1;
@ -212,7 +226,7 @@ public class CalculateAverage_roman_r_m {
@Override
public String toString() {
var bytes = new byte[len];
UNSAFE.copyMemory(null, ms.address() + offset, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, len);
UNSAFE.copyMemory(null, offset, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, len);
return new String(bytes, 0, len);
}
@ -238,23 +252,21 @@ public class CalculateAverage_roman_r_m {
int i = 0;
long base1 = ms.address() + offset;
long base2 = ms.address() + that.offset;
for (; i + 7 < len; i += 8) {
long l1 = UNSAFE.getLong(base1 + i);
long l2 = UNSAFE.getLong(base2 + i);
long l1 = UNSAFE.getLong(offset + i);
long l2 = UNSAFE.getLong(that.offset + i);
if (l1 != l2) {
return false;
}
}
if (len >= 8) {
long l1 = UNSAFE.getLong(base1 + len - 8);
long l2 = UNSAFE.getLong(base2 + len - 8);
long l1 = UNSAFE.getLong(offset + len - 8);
long l2 = UNSAFE.getLong(that.offset + len - 8);
return l1 == l2;
}
for (; i < len; i++) {
byte i1 = UNSAFE.getByte(base1 + i);
byte i2 = UNSAFE.getByte(base2 + i);
byte i1 = UNSAFE.getByte(offset + i);
byte i2 = UNSAFE.getByte(that.offset + i);
if (i1 != i2) {
return false;
}
@ -265,7 +277,7 @@ public class CalculateAverage_roman_r_m {
@Override
public int hashCode() {
if (hash == 0) {
long h = UNSAFE.getLong(ms.address() + offset);
long h = UNSAFE.getLong(offset);
h = Long.reverseBytes(h) >>> (8 * Math.max(0, 8 - len));
hash = (int) (h ^ (h >>> 32));
}