Use Arena MemorySegments rather than ByteBuffers. (#505)
This commit is contained in:
		| @@ -15,5 +15,5 @@ | ||||
| #  limitations under the License. | ||||
| # | ||||
|  | ||||
| JAVA_OPTS="" | ||||
| JAVA_OPTS="--enable-preview" | ||||
| java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ebarlas | ||||
|   | ||||
| @@ -18,9 +18,9 @@ package dev.morling.onebrc; | ||||
| import sun.misc.Unsafe; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.nio.BufferUnderflowException; | ||||
| import java.nio.ByteBuffer; | ||||
| import java.nio.ByteOrder; | ||||
| import java.lang.foreign.Arena; | ||||
| import java.lang.foreign.MemorySegment; | ||||
| import java.lang.foreign.ValueLayout; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.nio.file.Paths; | ||||
| @@ -30,9 +30,13 @@ import java.util.TreeMap; | ||||
|  | ||||
| public class CalculateAverage_ebarlas { | ||||
|  | ||||
|     private static final Arena ARENA = Arena.global(); | ||||
|  | ||||
|     private static final int MAX_KEY_SIZE = 100; | ||||
|     private static final int MAX_VAL_SIZE = 5; // -dd.d | ||||
|     private static final int MAX_LINE_SIZE = MAX_KEY_SIZE + MAX_VAL_SIZE + 2; // key, semicolon, val, newline | ||||
|     private static final int HASH_FACTOR = 433; | ||||
|     private static final int HASH_TBL_SIZE = 16_383; // range of allowed hash values, inclusive | ||||
|     private static final int HASH_TBL_SIZE = 32_767; // range of allowed hash values, inclusive | ||||
|  | ||||
|     private static final Unsafe UNSAFE = makeUnsafe(); | ||||
|  | ||||
| @@ -50,7 +54,7 @@ public class CalculateAverage_ebarlas { | ||||
|     public static void main(String[] args) throws IOException, InterruptedException { | ||||
|         var path = Paths.get("measurements.txt"); | ||||
|         var channel = FileChannel.open(path, StandardOpenOption.READ); | ||||
|         var numPartitions = (int) Math.max((channel.size() / Integer.MAX_VALUE) + 1, Runtime.getRuntime().availableProcessors()); | ||||
|         var numPartitions = Runtime.getRuntime().availableProcessors(); | ||||
|         var partitionSize = channel.size() / numPartitions; | ||||
|         var partitions = new Partition[numPartitions]; | ||||
|         var threads = new Thread[numPartitions]; | ||||
| @@ -63,8 +67,8 @@ public class CalculateAverage_ebarlas { | ||||
|             var pSize = pEnd - pStart; | ||||
|             Runnable r = () -> { | ||||
|                 try { | ||||
|                     var buffer = channel.map(FileChannel.MapMode.READ_ONLY, pStart, pSize).order(ByteOrder.LITTLE_ENDIAN); | ||||
|                     partitions[pIdx] = processBuffer(buffer, pIdx == 0); | ||||
|                     var ms = channel.map(FileChannel.MapMode.READ_ONLY, pStart, pSize, ARENA); | ||||
|                     partitions[pIdx] = processSegment(ms, pIdx == 0, pIdx == numPartitions - 1); | ||||
|                 } | ||||
|                 catch (IOException e) { | ||||
|                     throw new RuntimeException(e); | ||||
| @@ -142,7 +146,7 @@ public class CalculateAverage_ebarlas { | ||||
|             var merged = mergeFooterAndHeader(pPrev.footer, pNext.header); | ||||
|             if (merged != null && merged.length != 0) { | ||||
|                 if (merged[merged.length - 1] == '\n') { // fold into prev partition | ||||
|                     doProcessBuffer(ByteBuffer.wrap(merged).order(ByteOrder.LITTLE_ENDIAN), true, pPrev.stats); | ||||
|                     doProcessSegment(ARENA.allocateArray(ValueLayout.JAVA_BYTE, merged), 0, pPrev.stats, true); | ||||
|                 } | ||||
|                 else { // no newline appeared in partition, carry forward | ||||
|                     pNext.footer = merged; | ||||
| @@ -164,93 +168,121 @@ public class CalculateAverage_ebarlas { | ||||
|         return merged; | ||||
|     } | ||||
|  | ||||
|     private static Partition processBuffer(ByteBuffer buffer, boolean first) { | ||||
|         return doProcessBuffer(buffer, first, new Stats[HASH_TBL_SIZE + 1]); | ||||
|     private static Partition processSegment(MemorySegment ms, boolean first, boolean last) { | ||||
|         var stats = new Stats[HASH_TBL_SIZE + 1]; // vals range from [0, size] inclusive | ||||
|         var header = first ? null : readHeader(ms); | ||||
|         var keyStart = doProcessSegment(ms, header == null ? 0 : header.offset, stats, last); // last segment is complete | ||||
|         var footer = keyStart < ms.byteSize() ? readFooter(ms, keyStart) : null; | ||||
|         return new Partition(header == null ? null : header.data, footer, stats); | ||||
|     } | ||||
|  | ||||
|     private static Partition doProcessBuffer(ByteBuffer buffer, boolean first, Stats[] stats) { | ||||
|         var header = first ? null : readHeader(buffer); | ||||
|         var keyStart = reallyDoProcessBuffer(buffer, stats); | ||||
|         var footer = keyStart < buffer.limit() ? readFooter(buffer, keyStart) : null; | ||||
|         return new Partition(header, footer, stats); | ||||
|     } | ||||
|  | ||||
|     private static int reallyDoProcessBuffer(ByteBuffer buffer, Stats[] stats) { | ||||
|         long keyBaseAddr = UNSAFE.allocateMemory(MAX_KEY_SIZE); | ||||
|         int keyStart = 0; // start of key in buffer used for footer calc | ||||
|         try { // abort with exception to allow optimistic line processing | ||||
|             while (true) { // one line per iteration | ||||
|                 keyStart = buffer.position(); // preserve line start | ||||
|                 int keyHash = 0; // key hash code | ||||
|                 long keyAddr = keyBaseAddr; // address for next int | ||||
|                 int keyArrLen = 0; // number of key 4-byte ints | ||||
|                 int keyLastBytes; // occupancy in last byte (1, 2, 3, or 4) | ||||
|                 int val; // temperature value | ||||
|                 while (true) { | ||||
|                     int n = buffer.getInt(); | ||||
|                     byte b0 = (byte) (n & 0xFF); | ||||
|                     byte b1 = (byte) ((n >> 8) & 0xFF); | ||||
|                     byte b2 = (byte) ((n >> 16) & 0xFF); | ||||
|                     byte b3 = (byte) ((n >> 24) & 0xFF); | ||||
|                     if (b0 == ';') { // ...;1.1 | ||||
|                         val = getVal(buffer, b1, b2, b3, buffer.get()); | ||||
|                         keyLastBytes = 4; | ||||
|                         break; | ||||
|                     } | ||||
|                     else if (b1 == ';') { // ...a;1.1 | ||||
|                         val = getVal(buffer, b2, b3, buffer.get(), buffer.get()); | ||||
|                         UNSAFE.putInt(keyAddr, b0); | ||||
|                         keyLastBytes = 1; | ||||
|                         keyArrLen++; | ||||
|                         keyHash = HASH_FACTOR * keyHash + b0; | ||||
|                         break; | ||||
|                     } | ||||
|                     else if (b2 == ';') { // ...ab;1.1 | ||||
|                         val = getVal(buffer, b3, buffer.get(), buffer.get(), buffer.get()); | ||||
|                         UNSAFE.putInt(keyAddr, n & 0x0000FFFF); | ||||
|                         keyLastBytes = 2; | ||||
|                         keyArrLen++; | ||||
|                         keyHash = HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1; | ||||
|                         break; | ||||
|                     } | ||||
|                     else if (b3 == ';') { // ...abc;1.1 | ||||
|                         UNSAFE.putInt(keyAddr, n & 0x00FFFFFF); | ||||
|                         keyLastBytes = 3; | ||||
|                         keyArrLen++; | ||||
|                         keyHash = HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1) + b2; | ||||
|                         n = buffer.getInt(); | ||||
|                         b0 = (byte) (n & 0xFF); | ||||
|                         b1 = (byte) ((n >> 8) & 0xFF); | ||||
|                         b2 = (byte) ((n >> 16) & 0xFF); | ||||
|                         b3 = (byte) ((n >> 24) & 0xFF); | ||||
|                         val = getVal(buffer, b0, b1, b2, b3); | ||||
|                         break; | ||||
|                     } | ||||
|                     else { | ||||
|                         UNSAFE.putInt(keyAddr, n); | ||||
|                         keyArrLen++; | ||||
|                         keyAddr += 4; | ||||
|                         keyHash = HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1) + b2) + b3; | ||||
|                     } | ||||
|     private static long doProcessSegment(MemorySegment ms, long offset, Stats[] stats, boolean complete) { | ||||
|         long cursor = ms.address() + offset; | ||||
|         long keyBaseAddr = UNSAFE.allocateMemory(MAX_KEY_SIZE); // reusable target for current key data | ||||
|         long lineStart = cursor; // start of key in segment used for footer calc | ||||
|         long limit = ms.address() + (complete ? ms.byteSize() : ms.byteSize() - MAX_LINE_SIZE); // stop short of longest line, sweep up at the end | ||||
|         while (cursor < limit) { // one line per iteration | ||||
|             lineStart = cursor; // preserve line start | ||||
|             int keyHash = 0; // key hash code | ||||
|             long keyAddr = keyBaseAddr; // address for next int | ||||
|             int keyArrLen = 0; // number of key 4-byte ints | ||||
|             int keyLastBytes; // occupancy in last byte (1, 2, 3, or 4) | ||||
|             byte b0, b1, b2, b3; | ||||
|             while (true) { | ||||
|                 int n = UNSAFE.getInt(cursor); | ||||
|                 cursor += 4; | ||||
|                 b0 = (byte) (n & 0xFF); | ||||
|                 b1 = (byte) ((n >> 8) & 0xFF); | ||||
|                 b2 = (byte) ((n >> 16) & 0xFF); | ||||
|                 b3 = (byte) ((n >> 24) & 0xFF); | ||||
|                 if (b0 == ';') { // ...;1.1 | ||||
|                     keyLastBytes = 4; | ||||
|                     b0 = b1; | ||||
|                     b1 = b2; | ||||
|                     b2 = b3; | ||||
|                     b3 = (byte) (UNSAFE.getByte(cursor++) & 0xFF); | ||||
|                     break; | ||||
|                 } | ||||
|                 var idx = keyHash & HASH_TBL_SIZE; | ||||
|                 var st = stats[idx]; | ||||
|                 if (st == null) { // nothing in table, eagerly claim spot | ||||
|                     st = stats[idx] = newStats(keyBaseAddr, keyArrLen, keyLastBytes, keyHash); | ||||
|                 else if (b1 == ';') { // ...a;1.1 | ||||
|                     int k = n & 0xFF; | ||||
|                     UNSAFE.putInt(keyAddr, k); | ||||
|                     keyLastBytes = 1; | ||||
|                     keyArrLen++; | ||||
|                     keyHash = HASH_FACTOR * keyHash + b0; | ||||
|                     b0 = b2; | ||||
|                     b1 = b3; | ||||
|                     b2 = (byte) (UNSAFE.getByte(cursor++) & 0xFF); | ||||
|                     b3 = (byte) (UNSAFE.getByte(cursor++) & 0xFF); | ||||
|                     break; | ||||
|                 } | ||||
|                 else if (!equals(st.keyAddr, st.keyLen, keyBaseAddr, keyArrLen)) { | ||||
|                     st = findInTable(stats, keyHash, keyBaseAddr, keyArrLen, keyLastBytes); | ||||
|                 else if (b2 == ';') { // ...ab;1.1 | ||||
|                     int k = n & 0xFFFF; | ||||
|                     UNSAFE.putInt(keyAddr, k); | ||||
|                     keyLastBytes = 2; | ||||
|                     keyArrLen++; | ||||
|                     keyHash = HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1; | ||||
|                     b0 = b3; | ||||
|                     b1 = (byte) (UNSAFE.getByte(cursor++) & 0xFF); | ||||
|                     b2 = (byte) (UNSAFE.getByte(cursor++) & 0xFF); | ||||
|                     b3 = (byte) (UNSAFE.getByte(cursor++) & 0xFF); | ||||
|                     break; | ||||
|                 } | ||||
|                 else if (b3 == ';') { // ...abc;1.1 | ||||
|                     int k = n & 0xFFFFFF; | ||||
|                     UNSAFE.putInt(keyAddr, k); | ||||
|                     keyLastBytes = 3; | ||||
|                     keyArrLen++; | ||||
|                     keyHash = HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1) + b2; | ||||
|                     n = UNSAFE.getInt(cursor); | ||||
|                     cursor += 4; | ||||
|                     b0 = (byte) (n & 0xFF); | ||||
|                     b1 = (byte) ((n >> 8) & 0xFF); | ||||
|                     b2 = (byte) ((n >> 16) & 0xFF); | ||||
|                     b3 = (byte) ((n >> 24) & 0xFF); | ||||
|                     break; | ||||
|                 } | ||||
|                 else { | ||||
|                     UNSAFE.putInt(keyAddr, n); | ||||
|                     keyArrLen++; | ||||
|                     keyAddr += 4; | ||||
|                     keyHash = HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1) + b2) + b3; | ||||
|                 } | ||||
|                 st.min = Math.min(st.min, val); | ||||
|                 st.max = Math.max(st.max, val); | ||||
|                 st.sum += val; | ||||
|                 st.count++; | ||||
|             } | ||||
|             var idx = keyHash & HASH_TBL_SIZE; | ||||
|             var st = stats[idx]; | ||||
|             if (st == null) { // nothing in table, eagerly claim spot | ||||
|                 st = stats[idx] = newStats(keyBaseAddr, keyArrLen, keyLastBytes, keyHash); | ||||
|             } | ||||
|             else if (!equals(st.keyAddr, st.keyLen, keyBaseAddr, keyArrLen)) { | ||||
|                 st = findInTable(stats, keyHash, keyBaseAddr, keyArrLen, keyLastBytes); | ||||
|             } | ||||
|             int val; | ||||
|             if (b0 == '-') { | ||||
|                 if (b2 != '.') { // 6 bytes: -dd.dn | ||||
|                     var b = UNSAFE.getByte(cursor); | ||||
|                     cursor += 2; // adv beyond digit and newline | ||||
|                     val = -(((b1 - '0') * 10 + (b2 - '0')) * 10 + (b - '0')); | ||||
|                 } | ||||
|                 else { // 5 bytes: -d.dn | ||||
|                     cursor++; // newline | ||||
|                     val = -((b1 - '0') * 10 + (b3 - '0')); | ||||
|                 } | ||||
|             } | ||||
|             else { | ||||
|                 if (b1 != '.') { // 5 bytes: dd.dn | ||||
|                     cursor++; // newline | ||||
|                     val = ((b0 - '0') * 10 + (b1 - '0')) * 10 + (b3 - '0'); | ||||
|                 } | ||||
|                 else { // 4 bytes: d.dn | ||||
|                     val = (b0 - '0') * 10 + (b2 - '0'); | ||||
|                 } | ||||
|             } | ||||
|             st.min = Math.min(st.min, val); | ||||
|             st.max = Math.max(st.max, val); | ||||
|             st.sum += val; | ||||
|             st.count++; | ||||
|         } | ||||
|         catch (BufferUnderflowException ignore) { | ||||
|  | ||||
|         } | ||||
|         return keyStart; | ||||
|         return lineStart - ms.address(); | ||||
|     } | ||||
|  | ||||
|     private static boolean equals(long key1, int len1, long key2, int len2) { | ||||
| @@ -261,7 +293,7 @@ public class CalculateAverage_ebarlas { | ||||
|             return UNSAFE.getLong(key1) == UNSAFE.getLong(key2); | ||||
|         } | ||||
|         if (len1 == 3) { | ||||
|             return UNSAFE.getInt(key1) == UNSAFE.getInt(key2) && UNSAFE.getInt(key1 + 4) == UNSAFE.getInt(key2 + 4); | ||||
|             return UNSAFE.getLong(key1) == UNSAFE.getLong(key2) && UNSAFE.getInt(key1 + 8) == UNSAFE.getInt(key2 + 8); | ||||
|         } | ||||
|         if (len1 == 1) { | ||||
|             return UNSAFE.getInt(key1) == UNSAFE.getInt(key2); | ||||
| @@ -278,29 +310,6 @@ public class CalculateAverage_ebarlas { | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     private static int getVal(ByteBuffer buffer, byte b0, byte b1, byte b2, byte b3) { | ||||
|         if (b0 == '-') { | ||||
|             if (b2 != '.') { // 6 bytes: -dd.dn | ||||
|                 var b = buffer.get(); | ||||
|                 buffer.get(); // newline | ||||
|                 return -(((b1 - '0') * 10 + (b2 - '0')) * 10 + (b - '0')); | ||||
|             } | ||||
|             else { // 5 bytes: -d.dn | ||||
|                 buffer.get(); // newline | ||||
|                 return -((b1 - '0') * 10 + (b3 - '0')); | ||||
|             } | ||||
|         } | ||||
|         else { | ||||
|             if (b1 != '.') { // 5 bytes: dd.dn | ||||
|                 buffer.get(); // newline | ||||
|                 return ((b0 - '0') * 10 + (b1 - '0')) * 10 + (b3 - '0'); | ||||
|             } | ||||
|             else { // 4 bytes: d.dn | ||||
|                 return (b0 - '0') * 10 + (b2 - '0'); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static Stats findInTable(Stats[] stats, int hash, long keyAddr, int keyLen, int keyLastBytes) { // open-addressing scan | ||||
|         var idx = hash & HASH_TBL_SIZE; | ||||
|         var st = stats[idx]; | ||||
| @@ -321,18 +330,26 @@ public class CalculateAverage_ebarlas { | ||||
|         return new Stats(k, keyLen, keyLastBytes, hash); | ||||
|     } | ||||
|  | ||||
|     private static byte[] readFooter(ByteBuffer buffer, int lineStart) { // read from line start to current pos (end-of-input) | ||||
|         var footer = new byte[buffer.limit() - lineStart]; | ||||
|         buffer.get(lineStart, footer, 0, footer.length); | ||||
|     private static byte[] readFooter(MemorySegment ms, long offset) { // read from line start to current pos (end-of-input) | ||||
|         var footer = new byte[(int) (ms.byteSize() - offset)]; | ||||
|         for (int i = 0; i < footer.length; i++) { | ||||
|             footer[i] = ms.get(ValueLayout.JAVA_BYTE, offset + i); | ||||
|         } | ||||
|         return footer; | ||||
|     } | ||||
|  | ||||
|     private static byte[] readHeader(ByteBuffer buffer) { // read up to and including first newline (or end-of-input) | ||||
|         while (buffer.hasRemaining() && buffer.get() != '\n') | ||||
|     private static ByteArrayOffset readHeader(MemorySegment ms) { // read up to and including first newline (or end-of-input) | ||||
|         long offset = 0; | ||||
|         while (offset < ms.byteSize() && ms.get(ValueLayout.JAVA_BYTE, offset++) != '\n') | ||||
|             ; | ||||
|         var header = new byte[buffer.position()]; | ||||
|         buffer.get(0, header, 0, header.length); | ||||
|         return header; | ||||
|         var header = new byte[(int) offset]; | ||||
|         for (int i = 0; i < offset; i++) { | ||||
|             header[i] = ms.get(ValueLayout.JAVA_BYTE, i); | ||||
|         } | ||||
|         return new ByteArrayOffset(header, offset); | ||||
|     } | ||||
|  | ||||
|     record ByteArrayOffset(byte[] data, long offset) { | ||||
|     } | ||||
|  | ||||
|     private static class Partition { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user