diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java index fa49a76..c869b7d 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java @@ -15,11 +15,13 @@ */ package dev.morling.onebrc; +import sun.misc.Unsafe; + import java.io.File; -import java.io.IOException; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; +import java.lang.reflect.Field; import java.nio.channels.FileChannel; import java.nio.file.Paths; import java.util.ArrayList; @@ -34,6 +36,8 @@ public class CalculateAverage_roman_r_m { private static final String FILE = "./measurements.txt"; private static MemorySegment ms; + private static Unsafe UNSAFE; + // based on http://0x80.pl/notesen/2023-03-06-swar-find-any.html static long hasZeroByte(long l) { return ((l - 0x0101010101010101L) & ~(l) & 0x8080808080808080L); @@ -67,7 +71,11 @@ public class CalculateAverage_roman_r_m { return start + i; } - public static void main(String[] args) throws IOException { + public static void main(String[] args) throws Exception { + Field f = Unsafe.class.getDeclaredField("theUnsafe"); + f.setAccessible(true); + UNSAFE = (Unsafe) f.get(null); + long fileSize = new File(FILE).length(); var channel = FileChannel.open(Paths.get(FILE)); @@ -88,34 +96,29 @@ public class CalculateAverage_roman_r_m { long offset = chunkStart; while (offset < chunkEnd) { long start = offset; - long pos; + long pos = -1; - if (!lastChunk || chunkEnd - offset >= 8) { - long next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, offset); - - while ((pos = find(next, SEMICOLON_MASK)) < 0) { + while (chunkEnd - offset >= 8) { + long next = UNSAFE.getLong(ms.address() + offset); + pos = find(next, SEMICOLON_MASK); + if (pos >= 0) { + offset += pos; + break; + } + else { offset += 8; - if (!lastChunk || fileSize - offset >= 8) { - next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, offset); - } - else { - while (ms.get(ValueLayout.JAVA_BYTE, offset + pos) != ';') { - pos++; - } - break; - } } } - else { - pos = 0; - while (ms.get(ValueLayout.JAVA_BYTE, offset + pos) != ';') { - pos++; + if (pos < 0) { + while (UNSAFE.getByte(ms.address() + offset++) != ';') { } + offset--; } - offset += pos; + int len = (int) (offset - start); // TODO can we not copy and use a reference into the memory segment to perform table lookup? - MemorySegment.copy(ms, ValueLayout.JAVA_BYTE, start, station.buf, 0, len); + + station.offset = start; station.len = len; station.hash = 0; @@ -124,7 +127,7 @@ public class CalculateAverage_roman_r_m { long val; boolean neg; if (!lastChunk || fileSize - offset >= 8) { - long encodedVal = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, offset); + long encodedVal = UNSAFE.getLong(ms.address() + offset); neg = (encodedVal & (byte) '-') == (byte) '-'; if (neg) { encodedVal >>= 8; @@ -143,16 +146,16 @@ public class CalculateAverage_roman_r_m { } } else { - neg = ms.get(ValueLayout.JAVA_BYTE, offset) == '-'; + neg = UNSAFE.getByte(ms.address() + offset) == '-'; if (neg) { offset++; } - val = ms.get(ValueLayout.JAVA_BYTE, offset++) - '0'; + val = UNSAFE.getByte(ms.address() + offset++) - '0'; byte b; - while ((b = ms.get(ValueLayout.JAVA_BYTE, offset++)) != '.') { + while ((b = UNSAFE.getByte(ms.address() + offset++)) != '.') { val = val * 10 + (b - '0'); } - b = ms.get(ValueLayout.JAVA_BYTE, offset); + b = UNSAFE.getByte(ms.address() + offset); val = val * 10 + (b - '0'); offset += 2; } @@ -178,23 +181,22 @@ public class CalculateAverage_roman_r_m { static final class ByteString { - private byte[] buf = new byte[100]; + private long offset; private int len = 0; private int hash = 0; @Override public String toString() { - return new String(buf, 0, len); + var bytes = new byte[len]; + MemorySegment.copy(ms, ValueLayout.JAVA_BYTE, offset, bytes, 0, len); + return new String(bytes, 0, len); } public ByteString copy() { var copy = new ByteString(); + copy.offset = this.offset; copy.len = this.len; copy.hash = this.hash; - if (copy.buf.length < this.buf.length) { - copy.buf = new byte[this.buf.length]; - } - System.arraycopy(this.buf, 0, copy.buf, 0, this.len); return copy; } @@ -210,22 +212,34 @@ public class CalculateAverage_roman_r_m { if (len != that.len) return false; - // TODO use Vector - for (int i = 0; i < len; i++) { - if (buf[i] != that.buf[i]) { + int i = 0; + + long base1 = ms.address() + offset; + long base2 = ms.address() + that.offset; + for (; i + 3 < len; i += 4) { + int i1 = UNSAFE.getInt(base1 + i); + int i2 = UNSAFE.getInt(base2 + i); + if (i1 != i2) { + return false; + } + } + for (; i < len; i++) { + byte i1 = UNSAFE.getByte(base1 + i); + byte i2 = UNSAFE.getByte(base2 + i); + if (i1 != i2) { return false; } } - return true; } @Override public int hashCode() { if (hash == 0) { - for (int i = 0; i < len; i++) { - hash = 31 * hash + (buf[i] & 255); - } + // not sure why but it seems to be working a bit better + hash = UNSAFE.getInt(ms.address() + offset); + hash = hash >>> (8 * Math.max(0, 4 - len)); + hash |= len; } return hash; }