Dmitry challenge
This commit is contained in:
		
				
					committed by
					
						 Gunnar Morling
						Gunnar Morling
					
				
			
			
				
	
			
			
			
						parent
						
							3c36b5b0a8
						
					
				
				
					commit
					0ca7c485aa
				
			
							
								
								
									
										398
									
								
								src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										398
									
								
								src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,398 @@ | ||||
| /* | ||||
|  *  Copyright 2023 The original authors | ||||
|  * | ||||
|  *  Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  *  you may not use this file except in compliance with the License. | ||||
|  *  You may obtain a copy of the License at | ||||
|  * | ||||
|  *      http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  *  Unless required by applicable law or agreed to in writing, software | ||||
|  *  distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  *  See the License for the specific language governing permissions and | ||||
|  *  limitations under the License. | ||||
|  */ | ||||
| package dev.morling.onebrc; | ||||
|  | ||||
| import static java.lang.Math.toIntExact; | ||||
|  | ||||
| import java.nio.MappedByteBuffer; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.Comparator; | ||||
| import java.util.HashMap; | ||||
| import java.util.concurrent.Callable; | ||||
| import java.util.concurrent.ExecutorService; | ||||
| import java.util.concurrent.Executors; | ||||
|  | ||||
| import java.io.FileInputStream; | ||||
| import java.io.IOException; | ||||
| import java.util.concurrent.Future; | ||||
|  | ||||
| class ResultRow { | ||||
|     byte[] station; | ||||
|  | ||||
|     String stationString; | ||||
|     long min, max, count, suma; | ||||
|  | ||||
|     ResultRow() { | ||||
|     } | ||||
|  | ||||
|     ResultRow(byte[] station, long value) { | ||||
|         this.station = new byte[station.length]; | ||||
|         System.arraycopy(station, 0, this.station, 0, station.length); | ||||
|         this.min = value; | ||||
|         this.max = value; | ||||
|         this.count = 1; | ||||
|         this.suma = value; | ||||
|     } | ||||
|  | ||||
|     ResultRow(long value) { | ||||
|         this.min = value; | ||||
|         this.max = value; | ||||
|         this.count = 1; | ||||
|         this.suma = value; | ||||
|     } | ||||
|  | ||||
|     void setStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition) { | ||||
|         this.station = new byte[endPosition - startPosition]; | ||||
|         byteBuffer.slice(startPosition, station.length).get(this.station, 0, station.length); | ||||
|     } | ||||
|  | ||||
|     public String toString() { | ||||
|         stationString = new String(station, StandardCharsets.UTF_8); | ||||
|         return stationString + "=" + round(min / 10.0) + "/" + round(suma / 10.0 / count) + "/" + round(max / 10.0); | ||||
|     } | ||||
|  | ||||
|     private double round(double value) { | ||||
|         return Math.round(value * 10.0) / 10.0; | ||||
|     } | ||||
|  | ||||
|     ResultRow update(long newValue) { | ||||
|         this.count += 1; | ||||
|         this.suma += newValue; | ||||
|         if (newValue < this.min) { | ||||
|             this.min = newValue; | ||||
|         } | ||||
|         else if (newValue > this.max) { | ||||
|             this.max = newValue; | ||||
|         } | ||||
|         return this; | ||||
|     } | ||||
|  | ||||
|     ResultRow merge(ResultRow another) { | ||||
|         this.count += another.count; | ||||
|         this.suma += another.suma; | ||||
|         this.min = Math.min(this.min, another.min); | ||||
|         this.max = Math.max(this.max, another.max); | ||||
|         return this; | ||||
|     } | ||||
| } | ||||
|  | ||||
| class ByteArrayWrapper { | ||||
|     private final byte[] data; | ||||
|  | ||||
|     public ByteArrayWrapper(byte[] data) { | ||||
|         this.data = data; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public boolean equals(Object other) { | ||||
|         return Arrays.equals(data, ((ByteArrayWrapper) other).data); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public int hashCode() { | ||||
|         return Arrays.hashCode(data); | ||||
|     } | ||||
| } | ||||
|  | ||||
| class OpenHash { | ||||
|     ResultRow[] data; | ||||
|     int dataSizeMask; | ||||
|  | ||||
|     // ResultRow metrics = new ResultRow(); | ||||
|  | ||||
|     public OpenHash(int capacityPow2) { | ||||
|         assert capacityPow2 <= 20; | ||||
|         int dataSize = 1 << capacityPow2; | ||||
|         dataSizeMask = dataSize - 1; | ||||
|         data = new ResultRow[dataSize]; | ||||
|     } | ||||
|  | ||||
|     int hashByteArray(byte[] array) { | ||||
|         int result = 0; | ||||
|         long mask = 0; | ||||
|         for (int i = 0; i < array.length; ++i, mask = ((mask + 1) & 3)) { | ||||
|             result += array[i] << mask; | ||||
|         } | ||||
|         return result & dataSizeMask; | ||||
|     } | ||||
|  | ||||
|     void merge(byte[] station, long value, int hashValue) { | ||||
|         while (data[hashValue] != null && !Arrays.equals(station, data[hashValue].station)) { | ||||
|             hashValue += 1; | ||||
|             hashValue &= dataSizeMask; | ||||
|         } | ||||
|         if (data[hashValue] == null) { | ||||
|             data[hashValue] = new ResultRow(station, value); | ||||
|         } | ||||
|         else { | ||||
|             data[hashValue].update(value); | ||||
|         } | ||||
|         // metrics.update(delta); | ||||
|     } | ||||
|  | ||||
|     void merge(byte[] station, long value) { | ||||
|         merge(station, value, hashByteArray(station)); | ||||
|     } | ||||
|  | ||||
|     void merge(MappedByteBuffer byteBuffer, final int startPosition, final int endPosition, int hashValue, final long value) { | ||||
|         while (data[hashValue] != null && !equalsToStation(byteBuffer, startPosition, endPosition, data[hashValue].station)) { | ||||
|             hashValue += 1; | ||||
|             hashValue &= dataSizeMask; | ||||
|         } | ||||
|         if (data[hashValue] == null) { | ||||
|             data[hashValue] = new ResultRow(value); | ||||
|             data[hashValue].setStation(byteBuffer, startPosition, endPosition); | ||||
|         } | ||||
|         else { | ||||
|             data[hashValue].update(value); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     boolean equalsToStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition, byte[] station) { | ||||
|         if (endPosition - startPosition != station.length) { | ||||
|             return false; | ||||
|         } | ||||
|         for (int i = 0; i < station.length; ++i, ++startPosition) { | ||||
|             if (byteBuffer.get(startPosition) != station[i]) | ||||
|                 return false; | ||||
|         } | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     HashMap<ByteArrayWrapper, ResultRow> toJavaHashMap() { | ||||
|         HashMap<ByteArrayWrapper, ResultRow> result = new HashMap<>(20000); | ||||
|         for (int i = 0; i < data.length; ++i) { | ||||
|             if (data[i] != null) { | ||||
|                 var key = new ByteArrayWrapper(data[i].station); | ||||
|                 result.put(key, data[i]); | ||||
|             } | ||||
|         } | ||||
|         return result; | ||||
|     } | ||||
| } | ||||
|  | ||||
| public class CalculateAverage_bufistov { | ||||
|  | ||||
|     static final long LINE_SEPARATOR = '\n'; | ||||
|  | ||||
|     public static class FileRead implements Callable<HashMap<ByteArrayWrapper, ResultRow>> { | ||||
|  | ||||
|         private final FileChannel fileChannel; | ||||
|         private long currentLocation; | ||||
|         private int bytesToRead; | ||||
|  | ||||
|         private final int hashCapacityPow2 = 18; | ||||
|         private final int hashCapacityMask = (1 << hashCapacityPow2) - 1; | ||||
|  | ||||
|         public FileRead(long startLocation, int bytesToRead, FileChannel fileChannel) { | ||||
|             this.currentLocation = startLocation; | ||||
|             this.bytesToRead = bytesToRead; | ||||
|             this.fileChannel = fileChannel; | ||||
|         } | ||||
|  | ||||
|         @Override | ||||
|         public HashMap<ByteArrayWrapper, ResultRow> call() throws IOException { | ||||
|             try { | ||||
|                 OpenHash openHash = new OpenHash(hashCapacityPow2); | ||||
|                 log("Reading the channel: " + currentLocation + ":" + bytesToRead); | ||||
|                 byte[] suffix = new byte[128]; | ||||
|                 if (currentLocation > 0) { | ||||
|                     toLineBegin(suffix); | ||||
|                 } | ||||
|                 while (bytesToRead > 0) { | ||||
|                     int bufferSize = Math.min(1 << 24, bytesToRead); | ||||
|                     MappedByteBuffer byteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, currentLocation, bufferSize); | ||||
|                     bytesToRead -= bufferSize; | ||||
|                     currentLocation += bufferSize; | ||||
|                     int suffixBytes = 0; | ||||
|                     if (currentLocation < fileChannel.size()) { | ||||
|                         suffixBytes = toLineBegin(suffix); | ||||
|                     } | ||||
|                     processChunk(byteBuffer, bufferSize, suffix, suffixBytes, openHash); | ||||
|                 } | ||||
|                 log("Done Reading the channel: " + currentLocation + ":" + bytesToRead); | ||||
|                 return openHash.toJavaHashMap(); | ||||
|             } | ||||
|             catch (Exception e) { | ||||
|                 e.printStackTrace(); | ||||
|                 throw e; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         byte getByte(long position) throws IOException { | ||||
|             MappedByteBuffer byteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, position, 1); | ||||
|             return byteBuffer.get(); | ||||
|         } | ||||
|  | ||||
|         int toLineBegin(byte[] suffix) throws IOException { | ||||
|             int bytesConsumed = 0; | ||||
|             if (getByte(currentLocation - 1) != LINE_SEPARATOR) { | ||||
|                 while (getByte(currentLocation) != LINE_SEPARATOR) { // Small bug here if last chunk is less than a line and has no '\n' at the end. Valid input should have '\n' for all rows. | ||||
|                     suffix[bytesConsumed++] = getByte(currentLocation); | ||||
|                     ++currentLocation; | ||||
|                     --bytesToRead; | ||||
|                 } | ||||
|                 ++currentLocation; | ||||
|                 --bytesToRead; | ||||
|             } | ||||
|             return bytesConsumed; | ||||
|         } | ||||
|  | ||||
|         void processChunk(MappedByteBuffer byteBuffer, int bufferSize, byte[] suffix, int suffixBytes, OpenHash result) { | ||||
|             int nameBegin = 0; | ||||
|             int nameEnd = -1; | ||||
|             int numberBegin = -1; | ||||
|             int currentHash = 0; | ||||
|             int currentMask = 0; | ||||
|             int nameHash = 0; | ||||
|             for (int currentPosition = 0; currentPosition < bufferSize; ++currentPosition) { | ||||
|                 byte nextByte = byteBuffer.get(currentPosition); | ||||
|                 if (nextByte == ';') { | ||||
|                     nameEnd = currentPosition; | ||||
|                     numberBegin = currentPosition + 1; | ||||
|                     nameHash = currentHash & hashCapacityMask; | ||||
|                 } | ||||
|                 else if (nextByte == LINE_SEPARATOR) { | ||||
|                     long value = getValue(byteBuffer, numberBegin, currentPosition); | ||||
|                     // log("Station name: '" + getStationName(byteBuffer, nameBegin, nameEnd) + "' value: " + value + " hash: " + nameHash); | ||||
|                     result.merge(byteBuffer, nameBegin, nameEnd, nameHash, value); | ||||
|                     nameBegin = currentPosition + 1; | ||||
|                     currentHash = 0; | ||||
|                     currentMask = 0; | ||||
|                 } | ||||
|                 else { | ||||
|                     currentHash += (nextByte << currentMask); | ||||
|                     currentMask = (currentMask + 1) & 3; | ||||
|                 } | ||||
|             } | ||||
|             if (nameBegin < bufferSize) { | ||||
|                 byte[] lastLine = new byte[bufferSize - nameBegin + suffixBytes]; | ||||
|                 byte[] prefix = new byte[bufferSize - nameBegin]; | ||||
|                 byteBuffer.slice(nameBegin, prefix.length).get(prefix, 0, prefix.length); | ||||
|                 System.arraycopy(prefix, 0, lastLine, 0, prefix.length); | ||||
|                 System.arraycopy(suffix, 0, lastLine, prefix.length, suffixBytes); | ||||
|                 processLastLine(lastLine, result); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         void processLastLine(byte[] lastLine, OpenHash result) { | ||||
|             int numberBegin = -1; | ||||
|             byte[] stationName = null; | ||||
|             for (int i = 0; i < lastLine.length; ++i) { | ||||
|                 if (lastLine[i] == ';') { | ||||
|                     stationName = new byte[i]; | ||||
|                     System.arraycopy(lastLine, 0, stationName, 0, stationName.length); | ||||
|                     numberBegin = i + 1; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             long value = getValue(lastLine, numberBegin); | ||||
|             // log("Station name: '" + new String(stationName, StandardCharsets.UTF_8) + "' value: " + value); | ||||
|             result.merge(stationName, value); | ||||
|         } | ||||
|  | ||||
|         long getValue(MappedByteBuffer byteBuffer, int startLocation, int endLocation) { | ||||
|             byte nextByte = byteBuffer.get(startLocation); | ||||
|             boolean negate = nextByte == '-'; | ||||
|             long result = negate ? 0 : nextByte - '0'; | ||||
|             for (int i = startLocation + 1; i < endLocation; ++i) { | ||||
|                 nextByte = byteBuffer.get(i); | ||||
|                 if (nextByte != '.') { | ||||
|                     result *= 10; | ||||
|                     result += nextByte - '0'; | ||||
|                 } | ||||
|             } | ||||
|             return negate ? -result : result; | ||||
|         } | ||||
|  | ||||
|         long getValue(byte[] lastLine, int startLocation) { | ||||
|             byte nextByte = lastLine[startLocation]; | ||||
|             boolean negate = nextByte == '-'; | ||||
|             long result = negate ? 0 : nextByte - '0'; | ||||
|             for (int i = startLocation + 1; i < lastLine.length; ++i) { | ||||
|                 nextByte = lastLine[i]; | ||||
|                 if (nextByte != '.') { | ||||
|                     result *= 10; | ||||
|                     result += nextByte - '0'; | ||||
|                 } | ||||
|             } | ||||
|             return negate ? -result : result; | ||||
|         } | ||||
|  | ||||
|         String getStationName(MappedByteBuffer byteBuffer, int from, int to) { | ||||
|             byte[] bytes = new byte[to - from]; | ||||
|             byteBuffer.slice(from, to - from).get(0, bytes); | ||||
|             return new String(bytes, StandardCharsets.UTF_8); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     public static void main(String[] args) throws Exception { | ||||
|         String fileName = "measurements.txt"; | ||||
|         if (args.length > 0 && args[0].length() > 0) { | ||||
|             fileName = args[0]; | ||||
|         } | ||||
|         log("InputFile: " + fileName); | ||||
|         FileInputStream fileInputStream = new FileInputStream(fileName); | ||||
|         int numThreads = 32; | ||||
|         if (args.length > 1) { | ||||
|             numThreads = Integer.parseInt(args[1]); | ||||
|         } | ||||
|         log("NumThreads: " + numThreads); | ||||
|         FileChannel channel = fileInputStream.getChannel(); | ||||
|         final long fileSize = channel.size(); | ||||
|         long remaining_size = fileSize; | ||||
|         long chunk_size = Math.min((fileSize + numThreads - 1) / numThreads, Integer.MAX_VALUE - 5); | ||||
|  | ||||
|         ExecutorService executor = Executors.newFixedThreadPool(numThreads); | ||||
|  | ||||
|         long startLocation = 0; | ||||
|         ArrayList<Future<HashMap<ByteArrayWrapper, ResultRow>>> results = new ArrayList<>(numThreads); | ||||
|         while (remaining_size > 0) { | ||||
|             long actualSize = Math.min(chunk_size, remaining_size); | ||||
|             results.add(executor.submit(new FileRead(startLocation, toIntExact(actualSize), channel))); | ||||
|             remaining_size -= actualSize; | ||||
|             startLocation += actualSize; | ||||
|         } | ||||
|         executor.shutdown(); | ||||
|  | ||||
|         // Wait for all threads to finish | ||||
|         while (!executor.isTerminated()) { | ||||
|             Thread.yield(); | ||||
|         } | ||||
|         log("Finished all threads"); | ||||
|         fileInputStream.close(); | ||||
|         HashMap<ByteArrayWrapper, ResultRow> result = new HashMap<>(20000); | ||||
|         for (var future : results) { | ||||
|             for (var entry : future.get().entrySet()) { | ||||
|                 result.merge(entry.getKey(), entry.getValue(), ResultRow::merge); | ||||
|             } | ||||
|         } | ||||
|         ResultRow[] finalResult = result.values().toArray(new ResultRow[0]); | ||||
|         for (var row : finalResult) { | ||||
|             row.toString(); | ||||
|         } | ||||
|         Arrays.sort(finalResult, Comparator.comparing(a -> a.stationString)); | ||||
|         System.out.println("{" + String.join(", ", Arrays.stream(finalResult).map(ResultRow::toString).toList()) + "}"); | ||||
|         log("All done!"); | ||||
|     } | ||||
|  | ||||
|     static void log(String message) { | ||||
|         // System.err.println(Instant.now() + "[" + Thread.currentThread().getName() + "]: " + message); | ||||
|     } | ||||
| } | ||||
		Reference in New Issue
	
	Block a user