Charlibot - use memory mapping (#372)
* add memory map approach * cleanup
This commit is contained in:
		| @@ -16,5 +16,5 @@ | |||||||
| # | # | ||||||
|  |  | ||||||
|  |  | ||||||
| JAVA_OPTS="" | JAVA_OPTS="--enable-preview" | ||||||
| java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_charlibot | java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_charlibot | ||||||
|   | |||||||
							
								
								
									
										19
									
								
								prepare_charlibot.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										19
									
								
								prepare_charlibot.sh
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,19 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | # | ||||||
|  | #  Copyright 2023 The original authors | ||||||
|  | # | ||||||
|  | #  Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | #  you may not use this file except in compliance with the License. | ||||||
|  | #  You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #      http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | #  Unless required by applicable law or agreed to in writing, software | ||||||
|  | #  distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | #  See the License for the specific language governing permissions and | ||||||
|  | #  limitations under the License. | ||||||
|  | # | ||||||
|  |  | ||||||
|  | source "$HOME/.sdkman/bin/sdkman-init.sh" | ||||||
|  | sdk use java 21.0.1-graal 1>&2 | ||||||
| @@ -15,9 +15,14 @@ | |||||||
|  */ |  */ | ||||||
| package dev.morling.onebrc; | package dev.morling.onebrc; | ||||||
|  |  | ||||||
| import java.io.*; | import sun.misc.Unsafe; | ||||||
|  |  | ||||||
|  | import java.lang.foreign.Arena; | ||||||
|  | import java.lang.reflect.Field; | ||||||
|  | import java.nio.channels.FileChannel; | ||||||
| import java.nio.charset.StandardCharsets; | import java.nio.charset.StandardCharsets; | ||||||
| import java.nio.file.Path; | import java.nio.file.Path; | ||||||
|  | import java.nio.file.StandardOpenOption; | ||||||
| import java.util.*; | import java.util.*; | ||||||
| import java.util.concurrent.*; | import java.util.concurrent.*; | ||||||
| import java.util.stream.Collectors; | import java.util.stream.Collectors; | ||||||
| @@ -26,12 +31,23 @@ public class CalculateAverage_charlibot { | |||||||
|  |  | ||||||
|     private static final String FILE = "./measurements.txt"; |     private static final String FILE = "./measurements.txt"; | ||||||
|  |  | ||||||
|     private static final int BUFFER_SIZE = 1024 * 1024 * 10; |     private static final Unsafe UNSAFE = initUnsafe(); | ||||||
|  |  | ||||||
|  |     private static Unsafe initUnsafe() { | ||||||
|  |         try { | ||||||
|  |             final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); | ||||||
|  |             theUnsafe.setAccessible(true); | ||||||
|  |             return (Unsafe) theUnsafe.get(Unsafe.class); | ||||||
|  |         } | ||||||
|  |         catch (NoSuchFieldException | IllegalAccessException e) { | ||||||
|  |             throw new RuntimeException(e); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     private static final int MAP_CAPACITY = 16384; // Need at least 10,000 so 2^14 = 16384. Might need 2^15 = 32768. |     private static final int MAP_CAPACITY = 16384; // Need at least 10,000 so 2^14 = 16384. Might need 2^15 = 32768. | ||||||
|  |  | ||||||
|     public static void main(String[] args) throws Exception { |     public static void main(String[] args) throws Exception { | ||||||
|         multiThreadedReadingDoItAll(); |         memoryMap(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // Copied from Roy van Rijn's code |     // Copied from Roy van Rijn's code | ||||||
| @@ -75,111 +91,74 @@ public class CalculateAverage_charlibot { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     static int hashArraySlice(byte[] array, int offset, int length) { |     static class MeasurementMap3 { | ||||||
|         int hashcode = 0; |  | ||||||
|         for (int i = offset; i < offset + length; i++) { |  | ||||||
|             hashcode = 31 * hashcode + array[i]; |  | ||||||
|         } |  | ||||||
|         // Not sure the below actually helps much? |  | ||||||
|         // hashcode = hashcode >>> 16; // Do the same trick as-in hashmap since we're using power of 2 |  | ||||||
|         return hashcode; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     static class MeasurementMap { |         final Measurement[] measurements; | ||||||
|  |         final byte[][] cities; | ||||||
|  |  | ||||||
|         final int[][] map; |  | ||||||
|         final int capacity = MAP_CAPACITY; |         final int capacity = MAP_CAPACITY; | ||||||
|  |  | ||||||
|         final int numIntsToStoreCity = 25; // stores up to 100 characters. |         MeasurementMap3() { | ||||||
|         int minPos = numIntsToStoreCity; |             measurements = new Measurement[capacity]; | ||||||
|         int maxPos = numIntsToStoreCity + 1; |             cities = new byte[capacity][128]; // 100 bytes for the city. Round up to nearest power of 2. | ||||||
|         int sumPos = numIntsToStoreCity + 2; |  | ||||||
|         int countPos = numIntsToStoreCity + 3; |  | ||||||
|  |  | ||||||
|         MeasurementMap() { |  | ||||||
|             map = new int[capacity][numIntsToStoreCity + 4]; // length of string and then the city encoded cast bytes to int. then min, max, sum, count, |  | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         public void insert(byte[] array, int offset, int length, int value) { |         public void insert(long fromAddress, long toAddress, int hashcode, int value) { | ||||||
|             int hashcode = hashArraySlice(array, offset, length); |  | ||||||
|             int index = hashcode & (capacity - 1); // same trick as in hashmap. This is the same as (% capacity). |             int index = hashcode & (capacity - 1); // same trick as in hashmap. This is the same as (% capacity). | ||||||
|             tryInsert(index, array, offset, length, value); |             tryInsert(index, fromAddress, toAddress, value); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         private void tryInsert(int mapIndex, byte[] array, int offset, int length, int value) { |         private void tryInsert(int mapIndex, long fromAddress, long toAddress, int value) { | ||||||
|  |             byte length = (byte) (toAddress - fromAddress); | ||||||
|             outer: while (true) { |             outer: while (true) { | ||||||
|                 int[] jas = map[mapIndex]; |                 byte[] cityArray = cities[mapIndex]; | ||||||
|                 if (jas[0] == 0) { |                 Measurement jas = measurements[mapIndex]; | ||||||
|                     // just insert |                 if (jas != null) { | ||||||
|                     int i = 0; |                     if (cityArray[0] == length) { | ||||||
|                     int jasIndex = -1; |                         int i = 0; | ||||||
|                     while (i < length) { |                         while (i < length) { | ||||||
|                         byte b = array[i + offset]; |                             byte b = UNSAFE.getByte(fromAddress + i); | ||||||
|                         // i & 3 is the same as i % 4 |                             if (b != cityArray[i + 1]) { | ||||||
|                         if ((i & 3) == 0) { // when at i=0,4,8,12 then |                                 mapIndex = (mapIndex + 1) & (capacity - 1); | ||||||
|                             jasIndex++; |                                 continue outer; | ||||||
|  |                             } | ||||||
|  |                             i++; | ||||||
|                         } |                         } | ||||||
|                         jas[jasIndex] = jas[jasIndex] | ((b & 0xFF) << (8 * (i & 3))); |                         jas.min = min(value, jas.min); | ||||||
|                         i++; |                         jas.max = max(value, jas.max); | ||||||
|  |                         jas.sum += value; | ||||||
|  |                         jas.count += 1; | ||||||
|  |                         break; | ||||||
|  |                     } | ||||||
|  |                     else { | ||||||
|  |                         mapIndex = (mapIndex + 1) & (capacity - 1); | ||||||
|                     } |                     } | ||||||
|                     jas[minPos] = value; |  | ||||||
|                     jas[maxPos] = value; |  | ||||||
|                     jas[sumPos] = value; |  | ||||||
|                     jas[countPos] = 1; |  | ||||||
|                     break; |  | ||||||
|                 } |                 } | ||||||
|                 else { |                 else { | ||||||
|  |                     // just insert | ||||||
|                     int i = 0; |                     int i = 0; | ||||||
|                     int jasIndex = -1; |                     cityArray[0] = length; | ||||||
|                     while (i < length) { |                     while (i < length) { | ||||||
|                         byte b = array[i + offset]; |                         byte b = UNSAFE.getByte(fromAddress + i); | ||||||
|                         if ((i & 3) == 0) { // when at i=0,4,8,12,... then |                         cityArray[i + 1] = b; | ||||||
|                             jasIndex++; |  | ||||||
|                         } |  | ||||||
|                         byte inJas = (byte) (jas[jasIndex] >>> (8 * (i & 3))); |  | ||||||
|                         if (b != inJas) { |  | ||||||
|                             mapIndex = (mapIndex + 1) & (capacity - 1); |  | ||||||
|                             continue outer; |  | ||||||
|                         } |  | ||||||
|                         i++; |                         i++; | ||||||
|                     } |                     } | ||||||
|                     jas[minPos] = min(value, jas[minPos]); |                     measurements[mapIndex] = new Measurement(value); | ||||||
|                     jas[maxPos] = max(value, jas[maxPos]); |  | ||||||
|                     jas[sumPos] += value; |  | ||||||
|                     jas[countPos] += 1; |  | ||||||
|                     break; |                     break; | ||||||
|  |  | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         public HashMap<String, Measurement> toMap() { |         public HashMap<String, Measurement> toMap() { | ||||||
|             HashMap<String, Measurement> hashMap = new HashMap<>(); |             HashMap<String, Measurement> hashMap = new HashMap<>(); | ||||||
|             for (int[] jas : map) { |             for (int mapIndex = 0; mapIndex < cities.length; mapIndex++) { | ||||||
|                 if (jas[0] != 0) { |                 byte[] cityArray = cities[mapIndex]; | ||||||
|                     int jasIndex = 0; |                 Measurement measurement = measurements[mapIndex]; | ||||||
|                     byte[] array = new byte[numIntsToStoreCity * 4]; |                 if (measurement != null) { | ||||||
|                     while (jasIndex < numIntsToStoreCity) { |                     int length = cityArray[0]; | ||||||
|                         int tmp = jas[jasIndex]; |                     String city = new String(cityArray, 1, length, StandardCharsets.UTF_8); | ||||||
|                         array[jasIndex * 4] = (byte) tmp; |                     hashMap.put(city, measurement); | ||||||
|                         array[jasIndex * 4 + 1] = (byte) (tmp >>> 8); |  | ||||||
|                         array[jasIndex * 4 + 2] = (byte) (tmp >>> 16); |  | ||||||
|                         array[jasIndex * 4 + 3] = (byte) (tmp >>> 24); |  | ||||||
|                         jasIndex++; |  | ||||||
|                     } |  | ||||||
|                     int length = array.length; |  | ||||||
|                     for (int i = 0; i < array.length; i++) { |  | ||||||
|                         if (array[i] == 0) { |  | ||||||
|                             length = i; |  | ||||||
|                             break; |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                     String city = new String(array, 0, length, StandardCharsets.UTF_8); |  | ||||||
|                     Measurement m = new Measurement(0); |  | ||||||
|                     m.min = jas[minPos]; |  | ||||||
|                     m.max = jas[maxPos]; |  | ||||||
|                     m.sum = jas[sumPos]; |  | ||||||
|                     m.count = jas[countPos]; |  | ||||||
|                     hashMap.put(city, m); |  | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             return hashMap; |             return hashMap; | ||||||
| @@ -190,124 +169,68 @@ public class CalculateAverage_charlibot { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     public static void multiThreadedReadingDoItAll() throws Exception { |     public static long[] getChunks(int numChunks) throws Exception { | ||||||
|         File file = Path.of(FILE).toFile(); |         long[] chunks = new long[numChunks + 1]; | ||||||
|         long length = file.length(); |         try (FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { | ||||||
|         int numProcessors = Runtime.getRuntime().availableProcessors(); |             long fileSize = fileChannel.size(); | ||||||
|         long chunkToRead = length / numProcessors; |             long sizeOfChunk = fileSize / numChunks; | ||||||
|  |             var address = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); | ||||||
|         // make life easier by spending a bit of time up front to find line breaks around the chunks |             chunks[0] = address; | ||||||
|         final long[] startPositions = new long[numProcessors + 1]; |             for (int processIdx = 1; processIdx < numChunks; processIdx++) { | ||||||
|         try (RandomAccessFile raf = new RandomAccessFile(file, "r")) { |                 long chunkAddress = processIdx * sizeOfChunk + address; | ||||||
|             byte[] buffer = new byte[256]; |                 while (UNSAFE.getByte(chunkAddress) != '\n') { | ||||||
|             for (int processIdx = 1; processIdx < numProcessors; processIdx++) { |                     chunkAddress++; | ||||||
|                 long initialSeekPoint = processIdx * chunkToRead; |  | ||||||
|                 raf.seek(initialSeekPoint); |  | ||||||
|                 int bytesRead = raf.read(buffer); |  | ||||||
|                 // if (bytesRead != buffer.length) { |  | ||||||
|                 // throw new Exception("Actual read is not same as requested. " + bytesRead); |  | ||||||
|                 // } |  | ||||||
|                 int i = 0; |  | ||||||
|                 while (buffer[i] != '\n') { |  | ||||||
|                     i++; |  | ||||||
|                 } |                 } | ||||||
|                 initialSeekPoint += (i + 1); |                 chunkAddress++; | ||||||
|                 startPositions[processIdx] = initialSeekPoint; |                 chunks[processIdx] = chunkAddress; | ||||||
|             } |             } | ||||||
|             startPositions[numProcessors] = length; |             chunks[numChunks] = address + fileSize; | ||||||
|         } |         } | ||||||
|  |         return chunks; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     public static void memoryMap() throws Exception { | ||||||
|  |         int numProcessors = Runtime.getRuntime().availableProcessors(); | ||||||
|  |         long[] chunks = getChunks(numProcessors); | ||||||
|         try (ExecutorService executorService = Executors.newWorkStealingPool(numProcessors)) { |         try (ExecutorService executorService = Executors.newWorkStealingPool(numProcessors)) { | ||||||
|             Future[] results = new Future[numProcessors]; |             Future[] results = new Future[numProcessors]; | ||||||
|             for (int processIdx = 0; processIdx < numProcessors; processIdx++) { |             for (int processIdx = 0; processIdx < numProcessors; processIdx++) { | ||||||
|                 long seekPoint = startPositions[processIdx]; |                 int finalProcessIdx = processIdx; | ||||||
|                 long bytesToRead = startPositions[processIdx + 1] - startPositions[processIdx]; |  | ||||||
|                 Future<HashMap<String, Measurement>> future = executorService.submit(() -> { |                 Future<HashMap<String, Measurement>> future = executorService.submit(() -> { | ||||||
|                     MeasurementMap measurements = new MeasurementMap(); |                     long chunkIdx = chunks[finalProcessIdx]; | ||||||
|                     try (FileInputStream fis = new FileInputStream(file)) { |                     long chunkEnd = chunks[finalProcessIdx + 1]; | ||||||
|                         long actualSkipped = fis.skip(seekPoint); |                     MeasurementMap3 measurements = new MeasurementMap3(); | ||||||
|                         if (actualSkipped != seekPoint) { |                     while (chunkIdx < chunkEnd) { | ||||||
|                             throw new Exception("Uho oh"); |                         long cityStart = chunkIdx; | ||||||
|  |                         byte b; | ||||||
|  |                         int hashcode = 0; | ||||||
|  |                         while ((b = UNSAFE.getByte(chunkIdx)) != ';') { | ||||||
|  |                             hashcode = 31 * hashcode + b; | ||||||
|  |                             chunkIdx++; | ||||||
|                         } |                         } | ||||||
|                         byte[] buffer = new byte[BUFFER_SIZE]; |                         long cityEnd = chunkIdx; | ||||||
|                         long totalBytesRead = 0; |                         chunkIdx++; | ||||||
|                         int bytesRead; |                         int multiplier = 1; | ||||||
|                         int currentCityLength = 0; |                         b = UNSAFE.getByte(chunkIdx); | ||||||
|                         while ((bytesRead = fis.read(buffer, currentCityLength, buffer.length - currentCityLength)) != -1) { |                         if (b == '-') { | ||||||
|                             totalBytesRead -= currentCityLength; // avoid double counting. There must be a better way ! |                             multiplier = -1; | ||||||
|                             if (totalBytesRead >= bytesToRead && currentCityLength == 0) { |                             chunkIdx++; | ||||||
|                                 // we have read everything we intend to and there is no city in the buffer to finish processing |  | ||||||
|                                 return measurements.toMap(); |  | ||||||
|                             } |  | ||||||
|                             int i = 0; |  | ||||||
|                             int cityIndexStart = 0; |  | ||||||
|                             int cityLength; |  | ||||||
|                             int multiplier = 1; |  | ||||||
|                             int value = 0; |  | ||||||
|                             while (i < bytesRead + currentCityLength) { |  | ||||||
|                                 if (totalBytesRead >= bytesToRead) { |  | ||||||
|                                     // we have read everything we intend to for this chunk |  | ||||||
|                                     return measurements.toMap(); |  | ||||||
|                                 } |  | ||||||
|                                 if (buffer[i] == ';') { |  | ||||||
|                                     cityLength = i - cityIndexStart; |  | ||||||
|                                     i++; |  | ||||||
|                                     totalBytesRead++; |  | ||||||
|                                     if (i == bytesRead + currentCityLength) { |  | ||||||
|                                         System.arraycopy(buffer, cityIndexStart, buffer, 0, cityLength); |  | ||||||
|                                         bytesRead = fis.read(buffer, cityLength, buffer.length - cityLength); |  | ||||||
|                                         currentCityLength = cityLength; |  | ||||||
|                                         cityIndexStart = 0; |  | ||||||
|                                         i = cityLength; |  | ||||||
|                                     } |  | ||||||
|                                     if (buffer[i] == '-') { |  | ||||||
|                                         multiplier = -1; |  | ||||||
|                                         i++; |  | ||||||
|                                         totalBytesRead++; |  | ||||||
|                                         if (i == bytesRead + currentCityLength) { |  | ||||||
|                                             System.arraycopy(buffer, cityIndexStart, buffer, 0, cityLength); |  | ||||||
|                                             bytesRead = fis.read(buffer, cityLength, buffer.length - cityLength); |  | ||||||
|                                             currentCityLength = cityLength; |  | ||||||
|                                             cityIndexStart = 0; |  | ||||||
|                                             i = cityLength; |  | ||||||
|                                         } |  | ||||||
|                                     } |  | ||||||
|                                     while (buffer[i] != '\n') { |  | ||||||
|                                         if (buffer[i] != '.') { |  | ||||||
|                                             value = (value * 10) + (buffer[i] - '0'); |  | ||||||
|                                         } |  | ||||||
|                                         i++; |  | ||||||
|                                         totalBytesRead++; |  | ||||||
|                                         if (i == bytesRead + currentCityLength) { |  | ||||||
|                                             System.arraycopy(buffer, cityIndexStart, buffer, 0, cityLength); |  | ||||||
|                                             bytesRead = fis.read(buffer, cityLength, buffer.length - cityLength); |  | ||||||
|                                             currentCityLength = cityLength; |  | ||||||
|                                             cityIndexStart = 0; |  | ||||||
|                                             i = cityLength; |  | ||||||
|                                         } |  | ||||||
|                                     } |  | ||||||
|                                     value = value * multiplier; // is boolean check faster? |  | ||||||
|                                     measurements.insert(buffer, cityIndexStart, cityLength, value); |  | ||||||
|                                     if (totalBytesRead >= bytesToRead) { |  | ||||||
|                                         return measurements.toMap(); |  | ||||||
|                                     } |  | ||||||
|                                     // buffer[i] == \n so go one more |  | ||||||
|                                     cityIndexStart = i + 1; |  | ||||||
|                                     value = 0; |  | ||||||
|                                     multiplier = 1; |  | ||||||
|                                 } |  | ||||||
|                                 i++; |  | ||||||
|                                 totalBytesRead++; |  | ||||||
|                             } |  | ||||||
|                             currentCityLength = buffer.length - cityIndexStart; |  | ||||||
|                             System.arraycopy(buffer, cityIndexStart, buffer, 0, currentCityLength); |  | ||||||
|                         } |                         } | ||||||
|  |                         int value = 0; | ||||||
|  |                         while ((b = UNSAFE.getByte(chunkIdx)) != '\n') { | ||||||
|  |                             if (b != '.') { | ||||||
|  |                                 value = (value * 10) + (b - '0'); | ||||||
|  |                             } | ||||||
|  |                             chunkIdx++; | ||||||
|  |                         } | ||||||
|  |                         value = value * multiplier; | ||||||
|  |                         measurements.insert(cityStart, cityEnd, hashcode, value); | ||||||
|  |                         chunkIdx++; | ||||||
|                     } |                     } | ||||||
|                     return measurements.toMap(); |                     return measurements.toMap(); | ||||||
|                 }); |                 }); | ||||||
|                 results[processIdx] = future; |                 results[processIdx] = future; | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             final HashMap<String, Measurement> measurements = new HashMap<>(); |             final HashMap<String, Measurement> measurements = new HashMap<>(); | ||||||
|             for (Future f : results) { |             for (Future f : results) { | ||||||
|                 HashMap<String, Measurement> m = (HashMap<String, Measurement>) f.get(); |                 HashMap<String, Measurement> m = (HashMap<String, Measurement>) f.get(); | ||||||
| @@ -328,5 +251,4 @@ public class CalculateAverage_charlibot { | |||||||
|             System.out.println("}"); |             System.out.println("}"); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user