From ec27a47ce1226a653ab205db846e2ca84171bf5f Mon Sep 17 00:00:00 2001 From: Roman Musin <995612+roman-r-m@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:20:57 +0000 Subject: [PATCH] Version 4 - roman-r-m (#484) * Version 3 * trying to optimize memory access (-0.2s) - use smaller segments confined to thread - unload in parallel * Only call MemorySegment.address() once (~200ms) --- .../onebrc/CalculateAverage_roman_r_m.java | 64 +++++++++++-------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java index 2efb461..5c43824 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java @@ -71,10 +71,15 @@ public class CalculateAverage_roman_r_m { private final long end; private long offset; - public Worker(MemorySegment ms, long start, long end) { - this.ms = ms.asSlice(start, end - start); - this.offset = 0; - this.end = end - start; + public Worker(FileChannel channel, long start, long end) { + try { + this.ms = channel.map(FileChannel.MapMode.READ_ONLY, start, end - start, Arena.ofConfined()); + this.offset = ms.address(); + this.end = ms.address() + end - start; + } + catch (Exception e) { + throw new RuntimeException(e); + } } private void parseName(ByteString station) { @@ -82,7 +87,7 @@ public class CalculateAverage_roman_r_m { long pos = -1; while (end - offset > 8) { - long next = UNSAFE.getLong(ms.address() + offset); + long next = UNSAFE.getLong(offset); pos = find(next, SEMICOLON_MASK); if (pos >= 0) { offset += pos; @@ -93,7 +98,7 @@ public class CalculateAverage_roman_r_m { } } if (pos < 0) { - while (UNSAFE.getByte(ms.address() + offset++) != ';') { + while (UNSAFE.getByte(offset++) != ';') { } offset--; } @@ -107,7 +112,7 @@ public class CalculateAverage_roman_r_m { } long parseNumberFast() { - long encodedVal = UNSAFE.getLong(ms.address() + offset); + long encodedVal = UNSAFE.getLong(offset); var len = find(encodedVal, LINE_END_MASK); offset += len + 1; @@ -127,12 +132,12 @@ public class CalculateAverage_roman_r_m { } long parseNumberSlow() { - long val = UNSAFE.getByte(ms.address() + offset++) - '0'; + long val = UNSAFE.getByte(offset++) - '0'; byte b; - while ((b = UNSAFE.getByte(ms.address() + offset++)) != '.') { + while ((b = UNSAFE.getByte(offset++)) != '.') { val = val * 10 + (b - '0'); } - b = UNSAFE.getByte(ms.address() + offset); + b = UNSAFE.getByte(offset); val = val * 10 + (b - '0'); offset += 2; return val; @@ -140,7 +145,7 @@ public class CalculateAverage_roman_r_m { long parseNumber() { long val; - int neg = 1 - Integer.bitCount(UNSAFE.getByte(ms.address() + offset) & 0x10); + int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10); offset += neg; if (end - offset > 8) { @@ -178,18 +183,27 @@ public class CalculateAverage_roman_r_m { long fileSize = new File(FILE).length(); var channel = FileChannel.open(Paths.get(FILE)); - MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofAuto()); + MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofConfined()); int numThreads = fileSize > Integer.MAX_VALUE ? Runtime.getRuntime().availableProcessors() : 1; long chunk = fileSize / numThreads; + var bounds = IntStream.range(0, numThreads).mapToLong(i -> { + boolean lastChunk = i == numThreads - 1; + return lastChunk ? fileSize : nextNewline((i + 1) * chunk, ms); + }).toArray(); + + ms.unload(); + var result = IntStream.range(0, numThreads) .parallel() .mapToObj(i -> { - boolean lastChunk = i == numThreads - 1; - long chunkStart = i == 0 ? 0 : nextNewline(i * chunk, ms) + 1; - long chunkEnd = lastChunk ? fileSize : nextNewline((i + 1) * chunk, ms); - return new Worker(ms, chunkStart, chunkEnd).run(); + long start = i == 0 ? 0 : bounds[i - 1] + 1; + long end = bounds[i]; + Worker worker = new Worker(channel, start, end); + var res = worker.run(); + worker.ms.unload(); + return res; }).reduce((m1, m2) -> { m2.forEach((k, v) -> m1.merge(k, v, ResultRow::merge)); return m1; @@ -212,7 +226,7 @@ public class CalculateAverage_roman_r_m { @Override public String toString() { var bytes = new byte[len]; - UNSAFE.copyMemory(null, ms.address() + offset, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, len); + UNSAFE.copyMemory(null, offset, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, len); return new String(bytes, 0, len); } @@ -238,23 +252,21 @@ public class CalculateAverage_roman_r_m { int i = 0; - long base1 = ms.address() + offset; - long base2 = ms.address() + that.offset; for (; i + 7 < len; i += 8) { - long l1 = UNSAFE.getLong(base1 + i); - long l2 = UNSAFE.getLong(base2 + i); + long l1 = UNSAFE.getLong(offset + i); + long l2 = UNSAFE.getLong(that.offset + i); if (l1 != l2) { return false; } } if (len >= 8) { - long l1 = UNSAFE.getLong(base1 + len - 8); - long l2 = UNSAFE.getLong(base2 + len - 8); + long l1 = UNSAFE.getLong(offset + len - 8); + long l2 = UNSAFE.getLong(that.offset + len - 8); return l1 == l2; } for (; i < len; i++) { - byte i1 = UNSAFE.getByte(base1 + i); - byte i2 = UNSAFE.getByte(base2 + i); + byte i1 = UNSAFE.getByte(offset + i); + byte i2 = UNSAFE.getByte(that.offset + i); if (i1 != i2) { return false; } @@ -265,7 +277,7 @@ public class CalculateAverage_roman_r_m { @Override public int hashCode() { if (hash == 0) { - long h = UNSAFE.getLong(ms.address() + offset); + long h = UNSAFE.getLong(offset); h = Long.reverseBytes(h) >>> (8 * Math.max(0, 8 - len)); hash = (int) (h ^ (h >>> 32)); }