diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java index c869b7d..2efb461 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java @@ -24,17 +24,12 @@ import java.lang.foreign.ValueLayout; import java.lang.reflect.Field; import java.nio.channels.FileChannel; import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; import java.util.TreeMap; import java.util.stream.IntStream; public class CalculateAverage_roman_r_m { - public static final int DOT_3_RD_BYTE_MASK = (byte) '.' << 16; private static final String FILE = "./measurements.txt"; - private static MemorySegment ms; private static Unsafe UNSAFE; @@ -60,7 +55,7 @@ public class CalculateAverage_roman_r_m { return match != 0 ? firstSetByteIndex(match) : -1; } - static long nextNewline(long from) { + static long nextNewline(long from, MemorySegment ms) { long start = from; long i; long next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, start); @@ -71,6 +66,110 @@ public class CalculateAverage_roman_r_m { return start + i; } + static class Worker { + private final MemorySegment ms; + private final long end; + private long offset; + + public Worker(MemorySegment ms, long start, long end) { + this.ms = ms.asSlice(start, end - start); + this.offset = 0; + this.end = end - start; + } + + private void parseName(ByteString station) { + long start = offset; + long pos = -1; + + while (end - offset > 8) { + long next = UNSAFE.getLong(ms.address() + offset); + pos = find(next, SEMICOLON_MASK); + if (pos >= 0) { + offset += pos; + break; + } + else { + offset += 8; + } + } + if (pos < 0) { + while (UNSAFE.getByte(ms.address() + offset++) != ';') { + } + offset--; + } + + int len = (int) (offset - start); + station.offset = start; + station.len = len; + station.hash = 0; + + offset++; + } + + long parseNumberFast() { + long encodedVal = UNSAFE.getLong(ms.address() + offset); + + var len = find(encodedVal, LINE_END_MASK); + offset += len + 1; + + encodedVal ^= broadcast((byte) 0x30); + + long c0 = len == 4 ? 100 : 10; + long c1 = 10 * (len - 3); + long c2 = 4 - len; + long c3 = len - 3; + long a = (encodedVal & 0xFF) * c0; + long b = ((encodedVal & 0xFF00) >>> 8) * c1; + long c = ((encodedVal & 0xFF0000L) >>> 16) * c2; + long d = ((encodedVal & 0xFF000000L) >>> 24) * c3; + + return a + b + c + d; + } + + long parseNumberSlow() { + long val = UNSAFE.getByte(ms.address() + offset++) - '0'; + byte b; + while ((b = UNSAFE.getByte(ms.address() + offset++)) != '.') { + val = val * 10 + (b - '0'); + } + b = UNSAFE.getByte(ms.address() + offset); + val = val * 10 + (b - '0'); + offset += 2; + return val; + } + + long parseNumber() { + long val; + int neg = 1 - Integer.bitCount(UNSAFE.getByte(ms.address() + offset) & 0x10); + offset += neg; + + if (end - offset > 8) { + val = parseNumberFast(); + } + else { + val = parseNumberSlow(); + } + val *= 1 - 2 * neg; + return val; + } + + public TreeMap run() { + var resultStore = new ResultStore(); + var station = new ByteString(ms); + + while (offset < end) { + parseName(station); + long val = parseNumber(); + var a = resultStore.get(station); + a.min = Math.min(a.min, val); + a.max = Math.max(a.max, val); + a.sum += val; + a.count++; + } + return resultStore.toMap(); + } + } + public static void main(String[] args) throws Exception { Field f = Unsafe.class.getDeclaredField("theUnsafe"); f.setAccessible(true); @@ -79,98 +178,18 @@ public class CalculateAverage_roman_r_m { long fileSize = new File(FILE).length(); var channel = FileChannel.open(Paths.get(FILE)); - ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofAuto()); + MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofAuto()); int numThreads = fileSize > Integer.MAX_VALUE ? Runtime.getRuntime().availableProcessors() : 1; long chunk = fileSize / numThreads; + var result = IntStream.range(0, numThreads) .parallel() .mapToObj(i -> { boolean lastChunk = i == numThreads - 1; - long chunkStart = i == 0 ? 0 : nextNewline(i * chunk) + 1; - long chunkEnd = lastChunk ? fileSize : nextNewline((i + 1) * chunk); - - var resultStore = new ResultStore(); - var station = new ByteString(); - - long offset = chunkStart; - while (offset < chunkEnd) { - long start = offset; - long pos = -1; - - while (chunkEnd - offset >= 8) { - long next = UNSAFE.getLong(ms.address() + offset); - pos = find(next, SEMICOLON_MASK); - if (pos >= 0) { - offset += pos; - break; - } - else { - offset += 8; - } - } - if (pos < 0) { - while (UNSAFE.getByte(ms.address() + offset++) != ';') { - } - offset--; - } - - int len = (int) (offset - start); - // TODO can we not copy and use a reference into the memory segment to perform table lookup? - - station.offset = start; - station.len = len; - station.hash = 0; - - offset++; - - long val; - boolean neg; - if (!lastChunk || fileSize - offset >= 8) { - long encodedVal = UNSAFE.getLong(ms.address() + offset); - neg = (encodedVal & (byte) '-') == (byte) '-'; - if (neg) { - encodedVal >>= 8; - offset++; - } - - if ((encodedVal & DOT_3_RD_BYTE_MASK) == DOT_3_RD_BYTE_MASK) { - val = (encodedVal & 0xFF - 0x30) * 100 + (encodedVal >> 8 & 0xFF - 0x30) * 10 + (encodedVal >> 24 & 0xFF - 0x30); - offset += 5; - } - else { - // based on http://0x80.pl/articles/simd-parsing-int-sequences.html#parsing-and-conversion-of-signed-numbers - val = Long.compress(encodedVal, 0xFF00FFL) - 0x303030; - val = ((val * 2561) >> 8) & 0xff; - offset += 4; - } - } - else { - neg = UNSAFE.getByte(ms.address() + offset) == '-'; - if (neg) { - offset++; - } - val = UNSAFE.getByte(ms.address() + offset++) - '0'; - byte b; - while ((b = UNSAFE.getByte(ms.address() + offset++)) != '.') { - val = val * 10 + (b - '0'); - } - b = UNSAFE.getByte(ms.address() + offset); - val = val * 10 + (b - '0'); - offset += 2; - } - - if (neg) { - val = -val; - } - - var a = resultStore.get(station); - a.min = Math.min(a.min, val); - a.max = Math.max(a.max, val); - a.sum += val; - a.count++; - } - return resultStore.toMap(); + long chunkStart = i == 0 ? 0 : nextNewline(i * chunk, ms) + 1; + long chunkEnd = lastChunk ? fileSize : nextNewline((i + 1) * chunk, ms); + return new Worker(ms, chunkStart, chunkEnd).run(); }).reduce((m1, m2) -> { m2.forEach((k, v) -> m1.merge(k, v, ResultRow::merge)); return m1; @@ -181,19 +200,24 @@ public class CalculateAverage_roman_r_m { static final class ByteString { + private final MemorySegment ms; private long offset; private int len = 0; private int hash = 0; + ByteString(MemorySegment ms) { + this.ms = ms; + } + @Override public String toString() { var bytes = new byte[len]; - MemorySegment.copy(ms, ValueLayout.JAVA_BYTE, offset, bytes, 0, len); + UNSAFE.copyMemory(null, ms.address() + offset, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, len); return new String(bytes, 0, len); } public ByteString copy() { - var copy = new ByteString(); + var copy = new ByteString(ms); copy.offset = this.offset; copy.len = this.len; copy.hash = this.hash; @@ -216,13 +240,18 @@ public class CalculateAverage_roman_r_m { long base1 = ms.address() + offset; long base2 = ms.address() + that.offset; - for (; i + 3 < len; i += 4) { - int i1 = UNSAFE.getInt(base1 + i); - int i2 = UNSAFE.getInt(base2 + i); - if (i1 != i2) { + for (; i + 7 < len; i += 8) { + long l1 = UNSAFE.getLong(base1 + i); + long l2 = UNSAFE.getLong(base2 + i); + if (l1 != l2) { return false; } } + if (len >= 8) { + long l1 = UNSAFE.getLong(base1 + len - 8); + long l2 = UNSAFE.getLong(base2 + len - 8); + return l1 == l2; + } for (; i < len; i++) { byte i1 = UNSAFE.getByte(base1 + i); byte i2 = UNSAFE.getByte(base2 + i); @@ -236,10 +265,9 @@ public class CalculateAverage_roman_r_m { @Override public int hashCode() { if (hash == 0) { - // not sure why but it seems to be working a bit better - hash = UNSAFE.getInt(ms.address() + offset); - hash = hash >>> (8 * Math.max(0, 4 - len)); - hash |= len; + long h = UNSAFE.getLong(ms.address() + offset); + h = Long.reverseBytes(h) >>> (8 * Math.max(0, 8 - len)); + hash = (int) (h ^ (h >>> 32)); } return hash; } @@ -269,25 +297,40 @@ public class CalculateAverage_roman_r_m { } static class ResultStore { - private final ArrayList results = new ArrayList<>(10000); - private final Map indices = new HashMap<>(10000); + private static final int SIZE = 16384; + private final ByteString[] keys = new ByteString[SIZE]; + private final ResultRow[] values = new ResultRow[SIZE]; ResultRow get(ByteString s) { - var idx = indices.get(s); - if (idx != null) { - return results.get(idx); + int h = s.hashCode(); + int idx = (SIZE - 1) & h; + + int i = 0; + while (keys[idx] != null && !keys[idx].equals(s)) { + i++; + idx = (idx + i * i) % SIZE; + } + ResultRow result; + if (keys[idx] == null) { + keys[idx] = s.copy(); + result = new ResultRow(); + values[idx] = result; } else { - ResultRow next = new ResultRow(); - results.add(next); - indices.put(s.copy(), results.size() - 1); - return next; + result = values[idx]; + // TODO see it it makes any difference + // keys[idx].offset = s.offset; } + return result; } TreeMap toMap() { var result = new TreeMap(); - indices.forEach((name, idx) -> result.put(name.toString(), results.get(idx))); + for (int i = 0; i < SIZE; i++) { + if (keys[i] != null) { + result.put(keys[i].toString(), values[i]); + } + } return result; } }