From 6bd2a21686718f1596c7ef01fe3313b4d419ec50 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Serkan=20=C3=96ZAL?= <sozal@catchpoint.com>
Date: Sun, 28 Jan 2024 13:56:30 +0300
Subject: [PATCH] serkan-ozal's 2nd submission with some minor improvements:
 (#612)

- share a single memory arena and a single mapped file region between worker threads
- slightly reduce the number of instructions executed while processing a file region
---
 calculate_average_serkan-ozal.sh              |   4 +-
 .../onebrc/CalculateAverage_serkan_ozal.java  | 125 ++++++++++--------
 2 files changed, 74 insertions(+), 55 deletions(-)

diff --git a/calculate_average_serkan-ozal.sh b/calculate_average_serkan-ozal.sh
index a903c1d..857979b 100755
--- a/calculate_average_serkan-ozal.sh
+++ b/calculate_average_serkan-ozal.sh
@@ -23,8 +23,10 @@ if [[ ! "$(uname -s)" = "Darwin" ]]; then
   JAVA_OPTS="$JAVA_OPTS -XX:+UseTransparentHugePages"
 fi
 
+CONFIGS="USE_SHARED_ARENA=true USE_SHARED_REGION=true CLOSE_STDOUT_ON_RESULT=true"
+
 #echo "Process started at $(date +%s%N | cut -b1-13)"
-eval "exec 3< <({ CLOSE_STDOUT_ON_RESULT=true USE_SHARED_ARENA=true java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_serkan_ozal; })"
+eval "exec 3< <({ $CONFIGS java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_serkan_ozal; })"
 read <&3 result
 echo -e "$result"
 #echo "Process finished at $(date +%s%N | cut -b1-13)"
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java b/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java
index b025383..8087919 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java
@@ -69,9 +69,10 @@ public class CalculateAverage_serkan_ozal {
     private static final boolean USE_VTHREADS = getBooleanConfig("USE_VTHREADS", false);
     private static final int VTHREAD_COUNT = getIntegerConfig("VTHREAD_COUNT", 1024);
     private static final int REGION_COUNT = getIntegerConfig("REGION_COUNT", -1);
-    private static final boolean USE_SHARED_ARENA = getBooleanConfig("USE_SHARED_ARENA", false);
+    private static final boolean USE_SHARED_ARENA = getBooleanConfig("USE_SHARED_ARENA", true);
+    private static final boolean USE_SHARED_REGION = getBooleanConfig("USE_SHARED_REGION", true);
     private static final int MAP_CAPACITY = getIntegerConfig("MAP_CAPACITY", 1 << 17);
-    private static final boolean CLOSE_STDOUT_ON_RESULT = getBooleanConfig("CLOSE_STDOUT_ON_RESULT", false);
+    private static final boolean CLOSE_STDOUT_ON_RESULT = getBooleanConfig("CLOSE_STDOUT_ON_RESULT", true);
     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 
     // My dear old friend Unsafe
@@ -118,7 +119,11 @@ public class CalculateAverage_serkan_ozal {
             ExecutorService executor = USE_VTHREADS
                     ? Executors.newVirtualThreadPerTaskExecutor()
                     : Executors.newFixedThreadPool(concurrency, new RegionProcessorThreadFactory());
-
+            MemorySegment region = null;
+            if (USE_SHARED_REGION) {
+                arena = Arena.ofShared();
+                region = fc.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, arena);
+            }
             // Split whole file into regions and start region processors to handle those regions
             List<Future<Response>> futures = new ArrayList<>(regionCount);
             for (int i = 0; i < regionCount; i++) {
@@ -128,7 +133,7 @@ public class CalculateAverage_serkan_ozal {
                 long closestLineEndPos = (i < regionCount - 1)
                         ? findClosestLineEnd(fc, endPos, lineBuffer)
                         : fileSize;
-                Request request = new Request(fc, arena, startPos, closestLineEndPos, result);
+                Request request = new Request(fc, arena, region, startPos, closestLineEndPos, result);
                 RegionProcessor regionProcessor = createRegionProcessor(request);
                 Future<Response> future = executor.submit(regionProcessor);
                 futures.add(future);
@@ -230,19 +235,20 @@ public class CalculateAverage_serkan_ozal {
 
         private final FileChannel fc;
         private final Arena arena;
+        private final MemorySegment region;
         private final long start;
         private final long end;
         private final long size;
-        private final OpenMap map;
         private final Result result;
+        private OpenMap map;
 
         private RegionProcessor(Request request) {
             this.fc = request.fileChannel;
             this.arena = request.arena;
+            this.region = request.region;
             this.start = request.start;
             this.end = request.end;
             this.size = end - start;
-            this.map = new OpenMap();
             this.result = request.result;
         }
 
@@ -263,13 +269,21 @@ public class CalculateAverage_serkan_ozal {
         }
 
         private void processRegion() throws Exception {
+            // Create the map here, on the thread processing this region, rather than in the constructor
+            this.map = new OpenMap();
+
             boolean arenaGiven = arena != null;
             // If no shared global memory arena is used, create and use its own local memory arena
             Arena a = arenaGiven ? arena : Arena.ofConfined();
             try {
-                MemorySegment region = fc.map(FileChannel.MapMode.READ_ONLY, start, size, a);
+                boolean regionGiven = region != null;
+                MemorySegment r = regionGiven
+                        ? region
+                        : fc.map(FileChannel.MapMode.READ_ONLY, start, size, a);
+                long regionStart = regionGiven ? (r.address() + start) : r.address();
+                long regionEnd = regionStart + size;
 
-                doProcessRegion(region);
+                doProcessRegion(r, r.address(), regionStart, regionEnd);
                 if (VERBOSE) {
                     System.out.println("[Processor-" + Thread.currentThread().getName() + "] Region processed at " + System.currentTimeMillis());
                 }
@@ -311,25 +325,23 @@ public class CalculateAverage_serkan_ozal {
             }
         }
 
-        private void doProcessRegion(MemorySegment region) {
-            final long regionAddress = region.address();
-            final long regionSize = region.byteSize();
+        private void doProcessRegion(MemorySegment region, long regionAddress, long regionStart, long regionEnd) {
             final int vectorSize = BYTE_SPECIES.vectorByteSize();
-            final long regionMainLimit = regionSize - MAX_LINE_LENGTH;
+            final long regionMainLimit = regionEnd - MAX_LINE_LENGTH;
 
-            int regionPtr;
+            long regionPtr;
 
             // Read and process region - main
-            for (regionPtr = 0; regionPtr < regionMainLimit;) {
-                regionPtr = doProcessLine(region, regionAddress, vectorSize, regionPtr);
+            for (regionPtr = regionStart; regionPtr < regionMainLimit;) {
+                regionPtr = doProcessLine(region, regionAddress, regionPtr, vectorSize);
             }
 
             // Read and process region - tail
-            for (int i = regionPtr, j = regionPtr; i < regionSize;) {
-                byte b = U.getByte(regionAddress + i);
+            for (long i = regionPtr, j = regionPtr; i < regionEnd;) {
+                byte b = U.getByte(i);
                 if (b == KEY_VALUE_SEPARATOR) {
-                    long baseOffset = map.putKey(null, regionAddress, j, i - j);
-                    i = extractValue(regionAddress, i + 1, map, baseOffset);
+                    long baseOffset = map.putKey(null, j, (int) (i - j));
+                    i = extractValue(i + 1, map, baseOffset);
                     j = i;
                 }
                 else {
@@ -338,42 +350,41 @@ public class CalculateAverage_serkan_ozal {
             }
         }
 
-        private int doProcessLine(MemorySegment region, long regionAddress, int vectorSize, int i) {
+        private long doProcessLine(MemorySegment region, long regionAddress, long regionPtr, int vectorSize) {
             // Find key/value separator
             ////////////////////////////////////////////////////////////////////////////////////////////////////////
-            int keyStartIdx = i;
+            long keyStartPtr = regionPtr;
 
             // Vectorized search for key/value separator
-            ByteVector keyVector = ByteVector.fromMemorySegment(BYTE_SPECIES, region, i, NATIVE_BYTE_ORDER);
+            ByteVector keyVector = ByteVector.fromMemorySegment(BYTE_SPECIES, region, regionPtr - regionAddress, NATIVE_BYTE_ORDER);
             int keyValueSepOffset = keyVector.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue();
             // Check whether key/value separator is found in the first vector (city name is <= vector size)
             if (keyValueSepOffset == vectorSize) {
-                i += vectorSize;
+                regionPtr += vectorSize;
                 keyValueSepOffset = 0;
-                for (; U.getByte(regionAddress + i) != KEY_VALUE_SEPARATOR; i++)
+                for (; U.getByte(regionPtr) != KEY_VALUE_SEPARATOR; regionPtr++)
                     ;
                 // I have tried vectorized search for key/value separator in the remaining part,
                 // but since majority (99%) of the city names <= 16 bytes
                 // and other a few longer city names (have length < 16 and <= 32) not close to 32 bytes,
                 // byte by byte search is better in terms of performance (according to my experiments) and simplicity.
             }
-            i += keyValueSepOffset;
-            int keyLength = i - keyStartIdx;
-            i++;
+            regionPtr += keyValueSepOffset;
+            int keyLength = (int) (regionPtr - keyStartPtr);
+            regionPtr++;
             ////////////////////////////////////////////////////////////////////////////////////////////////////////
 
             // Put key and get map offset to put value
-            long baseOffset = map.putKey(keyVector, regionAddress, keyStartIdx, keyLength);
+            long entryOffset = map.putKey(keyVector, keyStartPtr, keyLength);
 
             // Extract value, put it into map and return next position in the region to continue processing from there
-            return extractValue(regionAddress, i, map, baseOffset);
+            return extractValue(regionPtr, map, entryOffset);
         }
-
     }
 
     // Credits: merykitty
-    private static int extractValue(long regionAddress, int idx, OpenMap map, long baseOffset) {
-        long word = U.getLong(regionAddress + idx);
+    private static long extractValue(long regionPtr, OpenMap map, long entryOffset) {
+        long word = U.getLong(regionPtr);
         if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
             word = Long.reverseBytes(word);
         }
@@ -388,10 +399,10 @@ public class CalculateAverage_serkan_ozal {
         int value = (int) ((absValue ^ signed) - signed);
 
         // Put extracted value into map
-        map.putValue(baseOffset, value);
+        map.putValue(entryOffset, value);
 
         // Return new position
-        return idx + (decimalSepPos >>> 3) + 3;
+        return regionPtr + (decimalSepPos >>> 3) + 3;
     }
 
     /**
@@ -401,13 +412,16 @@ public class CalculateAverage_serkan_ozal {
 
         private final FileChannel fileChannel;
         private final Arena arena;
+        private final MemorySegment region;
         private final long start;
         private final long end;
         private final Result result;
 
-        private Request(FileChannel fileChannel, Arena arena, long start, long end, Result result) {
+        private Request(FileChannel fileChannel, Arena arena, MemorySegment region,
+                        long start, long end, Result result) {
             this.fileChannel = fileChannel;
             this.arena = arena;
+            this.region = region;
             this.start = start;
             this.end = end;
             this.result = result;
@@ -555,8 +569,7 @@ public class CalculateAverage_serkan_ozal {
             return (Integer.rotateLeft(x * seed, rotate) ^ y) * seed;
         }
 
-        private long putKey(ByteVector keyVector, long regionAddress, long keyStartIdx, int keyLength) {
-            long keyStartAddress = regionAddress + keyStartIdx;
+        private long putKey(ByteVector keyVector, long keyStartAddress, int keyLength) {
             // Calculate hash of key
             int keyHash = calculateKeyHash(keyStartAddress, keyLength);
             // and get the position of the entry in the linear map based on calculated hash
@@ -565,23 +578,23 @@ public class CalculateAverage_serkan_ozal {
             // Start searching from the calculated position
             // and continue until find an available slot in case of hash collision
             // TODO Prevent infinite loop if all the slots are in use for other keys
-            for (long baseOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET + (idx * ENTRY_SIZE);; baseOffset = (baseOffset + ENTRY_SIZE) & ENTRY_MASK) {
-                int keyStartOffset = (int) baseOffset + KEY_OFFSET;
-                int keySize = U.getInt(data, baseOffset + KEY_SIZE_OFFSET);
+            for (long entryOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET + (idx * ENTRY_SIZE);; entryOffset = (entryOffset + ENTRY_SIZE) & ENTRY_MASK) {
+                int keyStartOffset = (int) entryOffset + KEY_OFFSET;
+                int keySize = U.getInt(data, entryOffset + KEY_SIZE_OFFSET);
                 // Check whether current index is empty (no another key is inserted yet)
                 if (keySize == 0) {
                     // Initialize entry slot for new key
-                    U.putShort(data, baseOffset + MIN_VALUE_OFFSET, Short.MAX_VALUE);
-                    U.putShort(data, baseOffset + MAX_VALUE_OFFSET, Short.MIN_VALUE);
-                    U.putInt(data, baseOffset + KEY_SIZE_OFFSET, keyLength);
+                    U.putShort(data, entryOffset + MIN_VALUE_OFFSET, Short.MAX_VALUE);
+                    U.putShort(data, entryOffset + MAX_VALUE_OFFSET, Short.MIN_VALUE);
+                    U.putInt(data, entryOffset + KEY_SIZE_OFFSET, keyLength);
                     U.copyMemory(null, keyStartAddress, data, keyStartOffset, keyLength);
-                    return baseOffset;
+                    return entryOffset;
                 }
                 // Check for hash collision (hashes are same, but keys are different).
                 // If there is no collision (both hashes and keys are equals), return current slot's offset.
                 // Otherwise, continue iterating until find an available slot.
                 if (keySize == keyLength && keysEqual(keyVector, keyStartAddress, keyLength, keyStartOffset)) {
-                    return baseOffset;
+                    return entryOffset;
                 }
             }
         }
@@ -633,15 +646,19 @@ public class CalculateAverage_serkan_ozal {
             return wordA == wordB;
         }
 
-        private void putValue(long baseOffset, int value) {
-            U.putInt(data, baseOffset + COUNT_OFFSET,
-                    U.getInt(data, baseOffset + COUNT_OFFSET) + 1);
-            U.putShort(data, baseOffset + MIN_VALUE_OFFSET,
-                    (short) Math.min(value, U.getShort(data, baseOffset + MIN_VALUE_OFFSET)));
-            U.putShort(data, baseOffset + MAX_VALUE_OFFSET,
-                    (short) Math.max(value, U.getShort(data, baseOffset + MAX_VALUE_OFFSET)));
-            U.putLong(data, baseOffset + VALUE_SUM_OFFSET,
-                    value + U.getLong(data, baseOffset + VALUE_SUM_OFFSET));
+        private void putValue(long entryOffset, int value) {
+            long countOffset = entryOffset + COUNT_OFFSET;
+            U.putInt(data, countOffset, U.getInt(data, countOffset) + 1);
+            long minValueOffset = entryOffset + MIN_VALUE_OFFSET;
+            if (value < U.getShort(data, minValueOffset)) {
+                U.putShort(data, minValueOffset, (short) value);
+            }
+            long maxValueOffset = entryOffset + MAX_VALUE_OFFSET;
+            if (value > U.getShort(data, maxValueOffset)) {
+                U.putShort(data, maxValueOffset, (short) value);
+            }
+            long sumOffset = entryOffset + VALUE_SUM_OFFSET;
+            U.putLong(data, sumOffset, U.getLong(data, sumOffset) + value);
         }
 
         private void merge(Map<String, KeyResult> resultMap) {