roman-r-m improved version (#368)

* remove unneeded check

* slightly improved hash code perf

* Use unsafe to access memory + untangle the code a bit

* Adhoc cache that works a bit better

* Store station names as offset into the memory segment + length; slightly change how the hash is calculated
This commit is contained in:
Roman Musin 2024-01-13 10:46:52 +00:00 committed by GitHub
parent 062f424c10
commit 092132afe9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -15,11 +15,13 @@
*/ */
package dev.morling.onebrc; package dev.morling.onebrc;
import sun.misc.Unsafe;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.lang.foreign.Arena; import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment; import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout; import java.lang.foreign.ValueLayout;
import java.lang.reflect.Field;
import java.nio.channels.FileChannel; import java.nio.channels.FileChannel;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.ArrayList; import java.util.ArrayList;
@ -34,6 +36,8 @@ public class CalculateAverage_roman_r_m {
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
private static MemorySegment ms; private static MemorySegment ms;
private static Unsafe UNSAFE;
// based on http://0x80.pl/notesen/2023-03-06-swar-find-any.html // based on http://0x80.pl/notesen/2023-03-06-swar-find-any.html
static long hasZeroByte(long l) { static long hasZeroByte(long l) {
return ((l - 0x0101010101010101L) & ~(l) & 0x8080808080808080L); return ((l - 0x0101010101010101L) & ~(l) & 0x8080808080808080L);
@ -67,7 +71,11 @@ public class CalculateAverage_roman_r_m {
return start + i; return start + i;
} }
public static void main(String[] args) throws IOException { public static void main(String[] args) throws Exception {
Field f = Unsafe.class.getDeclaredField("theUnsafe");
f.setAccessible(true);
UNSAFE = (Unsafe) f.get(null);
long fileSize = new File(FILE).length(); long fileSize = new File(FILE).length();
var channel = FileChannel.open(Paths.get(FILE)); var channel = FileChannel.open(Paths.get(FILE));
@ -88,34 +96,29 @@ public class CalculateAverage_roman_r_m {
long offset = chunkStart; long offset = chunkStart;
while (offset < chunkEnd) { while (offset < chunkEnd) {
long start = offset; long start = offset;
long pos; long pos = -1;
if (!lastChunk || chunkEnd - offset >= 8) { while (chunkEnd - offset >= 8) {
long next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, offset); long next = UNSAFE.getLong(ms.address() + offset);
pos = find(next, SEMICOLON_MASK);
while ((pos = find(next, SEMICOLON_MASK)) < 0) { if (pos >= 0) {
offset += pos;
break;
}
else {
offset += 8; offset += 8;
if (!lastChunk || fileSize - offset >= 8) {
next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, offset);
}
else {
while (ms.get(ValueLayout.JAVA_BYTE, offset + pos) != ';') {
pos++;
}
break;
}
} }
} }
else { if (pos < 0) {
pos = 0; while (UNSAFE.getByte(ms.address() + offset++) != ';') {
while (ms.get(ValueLayout.JAVA_BYTE, offset + pos) != ';') {
pos++;
} }
offset--;
} }
offset += pos;
int len = (int) (offset - start); int len = (int) (offset - start);
// TODO can we not copy and use a reference into the memory segment to perform table lookup? // TODO can we not copy and use a reference into the memory segment to perform table lookup?
MemorySegment.copy(ms, ValueLayout.JAVA_BYTE, start, station.buf, 0, len);
station.offset = start;
station.len = len; station.len = len;
station.hash = 0; station.hash = 0;
@ -124,7 +127,7 @@ public class CalculateAverage_roman_r_m {
long val; long val;
boolean neg; boolean neg;
if (!lastChunk || fileSize - offset >= 8) { if (!lastChunk || fileSize - offset >= 8) {
long encodedVal = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, offset); long encodedVal = UNSAFE.getLong(ms.address() + offset);
neg = (encodedVal & (byte) '-') == (byte) '-'; neg = (encodedVal & (byte) '-') == (byte) '-';
if (neg) { if (neg) {
encodedVal >>= 8; encodedVal >>= 8;
@ -143,16 +146,16 @@ public class CalculateAverage_roman_r_m {
} }
} }
else { else {
neg = ms.get(ValueLayout.JAVA_BYTE, offset) == '-'; neg = UNSAFE.getByte(ms.address() + offset) == '-';
if (neg) { if (neg) {
offset++; offset++;
} }
val = ms.get(ValueLayout.JAVA_BYTE, offset++) - '0'; val = UNSAFE.getByte(ms.address() + offset++) - '0';
byte b; byte b;
while ((b = ms.get(ValueLayout.JAVA_BYTE, offset++)) != '.') { while ((b = UNSAFE.getByte(ms.address() + offset++)) != '.') {
val = val * 10 + (b - '0'); val = val * 10 + (b - '0');
} }
b = ms.get(ValueLayout.JAVA_BYTE, offset); b = UNSAFE.getByte(ms.address() + offset);
val = val * 10 + (b - '0'); val = val * 10 + (b - '0');
offset += 2; offset += 2;
} }
@ -178,23 +181,22 @@ public class CalculateAverage_roman_r_m {
static final class ByteString { static final class ByteString {
private byte[] buf = new byte[100]; private long offset;
private int len = 0; private int len = 0;
private int hash = 0; private int hash = 0;
@Override @Override
public String toString() { public String toString() {
return new String(buf, 0, len); var bytes = new byte[len];
MemorySegment.copy(ms, ValueLayout.JAVA_BYTE, offset, bytes, 0, len);
return new String(bytes, 0, len);
} }
public ByteString copy() { public ByteString copy() {
var copy = new ByteString(); var copy = new ByteString();
copy.offset = this.offset;
copy.len = this.len; copy.len = this.len;
copy.hash = this.hash; copy.hash = this.hash;
if (copy.buf.length < this.buf.length) {
copy.buf = new byte[this.buf.length];
}
System.arraycopy(this.buf, 0, copy.buf, 0, this.len);
return copy; return copy;
} }
@ -210,22 +212,34 @@ public class CalculateAverage_roman_r_m {
if (len != that.len) if (len != that.len)
return false; return false;
// TODO use Vector int i = 0;
for (int i = 0; i < len; i++) {
if (buf[i] != that.buf[i]) { long base1 = ms.address() + offset;
long base2 = ms.address() + that.offset;
for (; i + 3 < len; i += 4) {
int i1 = UNSAFE.getInt(base1 + i);
int i2 = UNSAFE.getInt(base2 + i);
if (i1 != i2) {
return false;
}
}
for (; i < len; i++) {
byte i1 = UNSAFE.getByte(base1 + i);
byte i2 = UNSAFE.getByte(base2 + i);
if (i1 != i2) {
return false; return false;
} }
} }
return true; return true;
} }
@Override @Override
public int hashCode() { public int hashCode() {
if (hash == 0) { if (hash == 0) {
for (int i = 0; i < len; i++) { // not sure why but it seems to be working a bit better
hash = 31 * hash + (buf[i] & 255); hash = UNSAFE.getInt(ms.address() + offset);
} hash = hash >>> (8 * Math.max(0, 4 - len));
hash |= len;
} }
return hash; return hash;
} }