Second submission to keep a bit of dignity (#581)
* Dmitry challenge * Dmitry submit 2. Use MemorySegment of FileChannle and Unsafe to read bytes from disk. 4 seconds speedup in local test from 20s to 16s.
This commit is contained in:
		| @@ -17,4 +17,5 @@ | ||||
|  | ||||
|  | ||||
| #JAVA_OPTS="-verbose:gc" | ||||
| JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation" | ||||
| java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_bufistov $1 $2 | ||||
|   | ||||
| @@ -15,11 +15,17 @@ | ||||
|  */ | ||||
| package dev.morling.onebrc; | ||||
|  | ||||
| import sun.misc.Unsafe; | ||||
|  | ||||
| import static java.lang.Math.toIntExact; | ||||
|  | ||||
| import java.lang.foreign.Arena; | ||||
| import java.lang.reflect.Field; | ||||
| import java.nio.MappedByteBuffer; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.nio.file.Paths; | ||||
| import java.time.Instant; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.Comparator; | ||||
| @@ -32,7 +38,27 @@ import java.io.FileInputStream; | ||||
| import java.io.IOException; | ||||
| import java.util.concurrent.Future; | ||||
|  | ||||
| class ResultRow { | ||||
| class ByteArrayWrapper { | ||||
|     private final byte[] data; | ||||
|  | ||||
|     public ByteArrayWrapper(byte[] data) { | ||||
|         this.data = data; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public boolean equals(Object other) { | ||||
|         return Arrays.equals(data, ((ByteArrayWrapper) other).data); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public int hashCode() { | ||||
|         return Arrays.hashCode(data); | ||||
|     } | ||||
| } | ||||
|  | ||||
| public class CalculateAverage_bufistov { | ||||
|  | ||||
|     static class ResultRow { | ||||
|         byte[] station; | ||||
|  | ||||
|         String stationString; | ||||
| @@ -57,9 +83,11 @@ class ResultRow { | ||||
|             this.suma = value; | ||||
|         } | ||||
|  | ||||
|     void setStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition) { | ||||
|         this.station = new byte[endPosition - startPosition]; | ||||
|         byteBuffer.slice(startPosition, station.length).get(this.station, 0, station.length); | ||||
|         void setStation(long startPosition, long endPosition) { | ||||
|             this.station = new byte[(int) (endPosition - startPosition)]; | ||||
|             for (int i = 0; i < this.station.length; ++i) { | ||||
|                 this.station[i] = UNSAFE.getByte(startPosition + i); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public String toString() { | ||||
| @@ -71,7 +99,7 @@ class ResultRow { | ||||
|             return Math.round(value * 10.0) / 10.0; | ||||
|         } | ||||
|  | ||||
|     ResultRow update(long newValue) { | ||||
|         void update(long newValue) { | ||||
|             this.count += 1; | ||||
|             this.suma += newValue; | ||||
|             if (newValue < this.min) { | ||||
| @@ -80,7 +108,6 @@ class ResultRow { | ||||
|             else if (newValue > this.max) { | ||||
|                 this.max = newValue; | ||||
|             } | ||||
|         return this; | ||||
|         } | ||||
|  | ||||
|         ResultRow merge(ResultRow another) { | ||||
| @@ -90,27 +117,9 @@ class ResultRow { | ||||
|             this.max = Math.max(this.max, another.max); | ||||
|             return this; | ||||
|         } | ||||
| } | ||||
|  | ||||
| class ByteArrayWrapper { | ||||
|     private final byte[] data; | ||||
|  | ||||
|     public ByteArrayWrapper(byte[] data) { | ||||
|         this.data = data; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public boolean equals(Object other) { | ||||
|         return Arrays.equals(data, ((ByteArrayWrapper) other).data); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public int hashCode() { | ||||
|         return Arrays.hashCode(data); | ||||
|     } | ||||
| } | ||||
|  | ||||
| class OpenHash { | ||||
|     static class OpenHash { | ||||
|         ResultRow[] data; | ||||
|         int dataSizeMask; | ||||
|  | ||||
| @@ -150,26 +159,26 @@ class OpenHash { | ||||
|             merge(station, value, hashByteArray(station)); | ||||
|         } | ||||
|  | ||||
|     void merge(MappedByteBuffer byteBuffer, final int startPosition, final int endPosition, int hashValue, final long value) { | ||||
|         while (data[hashValue] != null && !equalsToStation(byteBuffer, startPosition, endPosition, data[hashValue].station)) { | ||||
|         void merge(final long startPosition, long endPosition, int hashValue, long value) { | ||||
|             while (data[hashValue] != null && !equalsToStation(startPosition, endPosition, data[hashValue].station)) { | ||||
|                 hashValue += 1; | ||||
|                 hashValue &= dataSizeMask; | ||||
|             } | ||||
|             if (data[hashValue] == null) { | ||||
|                 data[hashValue] = new ResultRow(value); | ||||
|             data[hashValue].setStation(byteBuffer, startPosition, endPosition); | ||||
|                 data[hashValue].setStation(startPosition, endPosition); | ||||
|             } | ||||
|             else { | ||||
|                 data[hashValue].update(value); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|     boolean equalsToStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition, byte[] station) { | ||||
|         boolean equalsToStation(long startPosition, long endPosition, byte[] station) { | ||||
|             if (endPosition - startPosition != station.length) { | ||||
|                 return false; | ||||
|             } | ||||
|             for (int i = 0; i < station.length; ++i, ++startPosition) { | ||||
|             if (byteBuffer.get(startPosition) != station[i]) | ||||
|                 if (UNSAFE.getByte(startPosition) != station[i]) | ||||
|                     return false; | ||||
|             } | ||||
|             return true; | ||||
| @@ -185,25 +194,38 @@ class OpenHash { | ||||
|             } | ||||
|             return result; | ||||
|         } | ||||
| } | ||||
|     } | ||||
|  | ||||
| public class CalculateAverage_bufistov { | ||||
|     static final Unsafe UNSAFE; | ||||
|  | ||||
|     static { | ||||
|         try { | ||||
|             Field unsafe = Unsafe.class.getDeclaredField("theUnsafe"); | ||||
|             unsafe.setAccessible(true); | ||||
|             UNSAFE = (Unsafe) unsafe.get(Unsafe.class); | ||||
|         } | ||||
|         catch (Throwable e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     static final long LINE_SEPARATOR = '\n'; | ||||
|  | ||||
|     public static class FileRead implements Callable<HashMap<ByteArrayWrapper, ResultRow>> { | ||||
|  | ||||
|         private final FileChannel fileChannel; | ||||
|  | ||||
|         private long currentLocation; | ||||
|         private int bytesToRead; | ||||
|         private long bytesToRead; | ||||
|  | ||||
|         private final int hashCapacityPow2 = 18; | ||||
|         private final int hashCapacityMask = (1 << hashCapacityPow2) - 1; | ||||
|         private static final int hashCapacityPow2 = 18; | ||||
|  | ||||
|         public FileRead(long startLocation, int bytesToRead, FileChannel fileChannel) { | ||||
|         static final int hashCapacityMask = (1 << hashCapacityPow2) - 1; | ||||
|  | ||||
|         public FileRead(FileChannel fileChannel, long startLocation, long bytesToRead, boolean firstSegment) { | ||||
|             this.fileChannel = fileChannel; | ||||
|             this.currentLocation = startLocation; | ||||
|             this.bytesToRead = bytesToRead; | ||||
|             this.fileChannel = fileChannel; | ||||
|         } | ||||
|  | ||||
|         @Override | ||||
| @@ -211,21 +233,13 @@ public class CalculateAverage_bufistov { | ||||
|             try { | ||||
|                 OpenHash openHash = new OpenHash(hashCapacityPow2); | ||||
|                 log("Reading the channel: " + currentLocation + ":" + bytesToRead); | ||||
|                 byte[] suffix = new byte[128]; | ||||
|                 if (currentLocation > 0) { | ||||
|                     toLineBegin(suffix); | ||||
|                 } | ||||
|                 while (bytesToRead > 0) { | ||||
|                     int bufferSize = Math.min(1 << 24, bytesToRead); | ||||
|                     MappedByteBuffer byteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, currentLocation, bufferSize); | ||||
|                     bytesToRead -= bufferSize; | ||||
|                     currentLocation += bufferSize; | ||||
|                     int suffixBytes = 0; | ||||
|                     if (currentLocation < fileChannel.size()) { | ||||
|                         suffixBytes = toLineBegin(suffix); | ||||
|                     } | ||||
|                     processChunk(byteBuffer, bufferSize, suffix, suffixBytes, openHash); | ||||
|                     toLineBeginPrefix(); | ||||
|                 } | ||||
|                 toLineBeginSuffix(); | ||||
|                 var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, currentLocation, bytesToRead, Arena.global()); | ||||
|                 currentLocation = memorySegment.address(); | ||||
|                 processChunk(openHash); | ||||
|                 log("Done Reading the channel: " + currentLocation + ":" + bytesToRead); | ||||
|                 return openHash.toJavaHashMap(); | ||||
|             } | ||||
| @@ -240,39 +254,40 @@ public class CalculateAverage_bufistov { | ||||
|             return byteBuffer.get(); | ||||
|         } | ||||
|  | ||||
|         int toLineBegin(byte[] suffix) throws IOException { | ||||
|             int bytesConsumed = 0; | ||||
|             if (getByte(currentLocation - 1) != LINE_SEPARATOR) { | ||||
|                 while (getByte(currentLocation) != LINE_SEPARATOR) { // Small bug here if last chunk is less than a line and has no '\n' at the end. Valid input should have '\n' for all rows. | ||||
|                     suffix[bytesConsumed++] = getByte(currentLocation); | ||||
|         void toLineBeginPrefix() throws IOException { | ||||
|             while (getByte(currentLocation - 1) != LINE_SEPARATOR) { | ||||
|                 ++currentLocation; | ||||
|                 --bytesToRead; | ||||
|             } | ||||
|                 ++currentLocation; | ||||
|                 --bytesToRead; | ||||
|             } | ||||
|             return bytesConsumed; | ||||
|         } | ||||
|  | ||||
|         void processChunk(MappedByteBuffer byteBuffer, int bufferSize, byte[] suffix, int suffixBytes, OpenHash result) { | ||||
|             int nameBegin = 0; | ||||
|             int nameEnd = -1; | ||||
|             int numberBegin = -1; | ||||
|         void toLineBeginSuffix() throws IOException { | ||||
|             while (getByte(currentLocation + bytesToRead - 1) != LINE_SEPARATOR) { | ||||
|                 ++bytesToRead; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         void processChunk(OpenHash result) { | ||||
|             long nameBegin = currentLocation; | ||||
|             long nameEnd = -1; | ||||
|             long numberBegin = -1; | ||||
|             int currentHash = 0; | ||||
|             int currentMask = 0; | ||||
|             int nameHash = 0; | ||||
|             for (int currentPosition = 0; currentPosition < bufferSize; ++currentPosition) { | ||||
|                 byte nextByte = byteBuffer.get(currentPosition); | ||||
|             long end = currentLocation + bytesToRead; | ||||
|             byte nextByte; | ||||
|             for (; currentLocation < end; ++currentLocation) { | ||||
|                 nextByte = UNSAFE.getByte(currentLocation); | ||||
|                 if (nextByte == ';') { | ||||
|                     nameEnd = currentPosition; | ||||
|                     numberBegin = currentPosition + 1; | ||||
|                     nameEnd = currentLocation; | ||||
|                     numberBegin = currentLocation + 1; | ||||
|                     nameHash = currentHash & hashCapacityMask; | ||||
|                 } | ||||
|                 else if (nextByte == LINE_SEPARATOR) { | ||||
|                     long value = getValue(byteBuffer, numberBegin, currentPosition); | ||||
|                     // log("Station name: '" + getStationName(byteBuffer, nameBegin, nameEnd) + "' value: " + value + " hash: " + nameHash); | ||||
|                     result.merge(byteBuffer, nameBegin, nameEnd, nameHash, value); | ||||
|                     nameBegin = currentPosition + 1; | ||||
|                     long value = getValue(numberBegin, currentLocation); | ||||
|                     // log("Station name: '" + getStationName(nameBegin, nameEnd) + "' value: " + value + " hash: " + nameHash); | ||||
|                     result.merge(nameBegin, nameEnd, nameHash, value); | ||||
|                     nameBegin = currentLocation + 1; | ||||
|                     currentHash = 0; | ||||
|                     currentMask = 0; | ||||
|                 } | ||||
| @@ -281,38 +296,14 @@ public class CalculateAverage_bufistov { | ||||
|                     currentMask = (currentMask + 1) & 3; | ||||
|                 } | ||||
|             } | ||||
|             if (nameBegin < bufferSize) { | ||||
|                 byte[] lastLine = new byte[bufferSize - nameBegin + suffixBytes]; | ||||
|                 byte[] prefix = new byte[bufferSize - nameBegin]; | ||||
|                 byteBuffer.slice(nameBegin, prefix.length).get(prefix, 0, prefix.length); | ||||
|                 System.arraycopy(prefix, 0, lastLine, 0, prefix.length); | ||||
|                 System.arraycopy(suffix, 0, lastLine, prefix.length, suffixBytes); | ||||
|                 processLastLine(lastLine, result); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         void processLastLine(byte[] lastLine, OpenHash result) { | ||||
|             int numberBegin = -1; | ||||
|             byte[] stationName = null; | ||||
|             for (int i = 0; i < lastLine.length; ++i) { | ||||
|                 if (lastLine[i] == ';') { | ||||
|                     stationName = new byte[i]; | ||||
|                     System.arraycopy(lastLine, 0, stationName, 0, stationName.length); | ||||
|                     numberBegin = i + 1; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             long value = getValue(lastLine, numberBegin); | ||||
|             // log("Station name: '" + new String(stationName, StandardCharsets.UTF_8) + "' value: " + value); | ||||
|             result.merge(stationName, value); | ||||
|         } | ||||
|  | ||||
|         long getValue(MappedByteBuffer byteBuffer, int startLocation, int endLocation) { | ||||
|             byte nextByte = byteBuffer.get(startLocation); | ||||
|         long getValue(long startLocation, long endLocation) { | ||||
|             byte nextByte = UNSAFE.getByte(startLocation); | ||||
|             boolean negate = nextByte == '-'; | ||||
|             long result = negate ? 0 : nextByte - '0'; | ||||
|             for (int i = startLocation + 1; i < endLocation; ++i) { | ||||
|                 nextByte = byteBuffer.get(i); | ||||
|             for (long i = startLocation + 1; i < endLocation; ++i) { | ||||
|                 nextByte = UNSAFE.getByte(i); | ||||
|                 if (nextByte != '.') { | ||||
|                     result *= 10; | ||||
|                     result += nextByte - '0'; | ||||
| @@ -321,23 +312,11 @@ public class CalculateAverage_bufistov { | ||||
|             return negate ? -result : result; | ||||
|         } | ||||
|  | ||||
|         long getValue(byte[] lastLine, int startLocation) { | ||||
|             byte nextByte = lastLine[startLocation]; | ||||
|             boolean negate = nextByte == '-'; | ||||
|             long result = negate ? 0 : nextByte - '0'; | ||||
|             for (int i = startLocation + 1; i < lastLine.length; ++i) { | ||||
|                 nextByte = lastLine[i]; | ||||
|                 if (nextByte != '.') { | ||||
|                     result *= 10; | ||||
|                     result += nextByte - '0'; | ||||
|         String getStationName(long from, long to) { | ||||
|             byte[] bytes = new byte[(int) (to - from)]; | ||||
|             for (int i = 0; i < bytes.length; ++i) { | ||||
|                 bytes[i] = UNSAFE.getByte(from + i); | ||||
|             } | ||||
|             } | ||||
|             return negate ? -result : result; | ||||
|         } | ||||
|  | ||||
|         String getStationName(MappedByteBuffer byteBuffer, int from, int to) { | ||||
|             byte[] bytes = new byte[to - from]; | ||||
|             byteBuffer.slice(from, to - from).get(0, bytes); | ||||
|             return new String(bytes, StandardCharsets.UTF_8); | ||||
|         } | ||||
|     } | ||||
| @@ -349,7 +328,7 @@ public class CalculateAverage_bufistov { | ||||
|         } | ||||
|         log("InputFile: " + fileName); | ||||
|         FileInputStream fileInputStream = new FileInputStream(fileName); | ||||
|         int numThreads = 32; | ||||
|         int numThreads = 2 * Runtime.getRuntime().availableProcessors(); | ||||
|         if (args.length > 1) { | ||||
|             numThreads = Integer.parseInt(args[1]); | ||||
|         } | ||||
| @@ -363,9 +342,12 @@ public class CalculateAverage_bufistov { | ||||
|  | ||||
|         long startLocation = 0; | ||||
|         ArrayList<Future<HashMap<ByteArrayWrapper, ResultRow>>> results = new ArrayList<>(numThreads); | ||||
|         var fileChannel = FileChannel.open(Paths.get(fileName)); | ||||
|         boolean firstSegment = true; | ||||
|         while (remaining_size > 0) { | ||||
|             long actualSize = Math.min(chunk_size, remaining_size); | ||||
|             results.add(executor.submit(new FileRead(startLocation, toIntExact(actualSize), channel))); | ||||
|             results.add(executor.submit(new FileRead(fileChannel, startLocation, toIntExact(actualSize), firstSegment))); | ||||
|             firstSegment = false; | ||||
|             remaining_size -= actualSize; | ||||
|             startLocation += actualSize; | ||||
|         } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user