Attempt nicer threading via streams and spliterators
This commit is contained in:
		
				
					committed by
					
						 Gunnar Morling
						Gunnar Morling
					
				
			
			
				
	
			
			
			
						parent
						
							b2cd84c6bc
						
					
				
				
					commit
					6aa63e1bd5
				
			| @@ -21,91 +21,68 @@ import java.nio.ByteBuffer; | |||||||
| import java.nio.channels.FileChannel; | import java.nio.channels.FileChannel; | ||||||
| import java.nio.charset.StandardCharsets; | import java.nio.charset.StandardCharsets; | ||||||
| import java.util.*; | import java.util.*; | ||||||
|  | import java.util.stream.Collectors; | ||||||
|  | import java.util.stream.StreamSupport; | ||||||
|  |  | ||||||
| public class CalculateAverage_palmr { | public class CalculateAverage_palmr { | ||||||
|  |  | ||||||
|     private static final String FILE = "./measurements.txt"; |     private static final String FILE = "./measurements.txt"; | ||||||
|     public static final int CHUNK_SIZE = 1024 * 1024 * 10; // Trial and error showed ~10MB to be a good size on our machine |     private static final int CHUNK_SIZE = 1024 * 1024 * 10; // Trial and error showed ~10MB to be a good size on our machine | ||||||
|     public static final int LITTLE_CHUNK_SIZE = 128; // Enough bytes to cover a station name and measurement value :fingers-crossed: |     private static final int STATION_NAME_BUFFER_SIZE = 50; | ||||||
|     public static final int STATION_NAME_BUFFER_SIZE = 50; |     private static final int THREAD_COUNT = Math.min(8, Runtime.getRuntime().availableProcessors()); | ||||||
|     public static final int THREAD_COUNT = Math.min(8, Runtime.getRuntime().availableProcessors()); |     private static final char SEPARATOR_CHAR = ';'; | ||||||
|  |     private static final char NEWLINE_CHAR = '\n'; | ||||||
|  |     private static final char MINUS_CHAR = '-'; | ||||||
|  |     private static final char DECIMAL_POINT_CHAR = '.'; | ||||||
|  |  | ||||||
|     public static void main(String[] args) throws IOException { |     public static void main(String[] args) throws IOException { | ||||||
|  |  | ||||||
|         @SuppressWarnings("resource") // It's faster to leak the file than be well-behaved |         @SuppressWarnings("resource") // It's faster to leak the file than be well-behaved | ||||||
|         RandomAccessFile file = new RandomAccessFile(FILE, "r"); |         final var file = new RandomAccessFile(FILE, "r"); | ||||||
|         FileChannel channel = file.getChannel(); |         final var channel = file.getChannel(); | ||||||
|         long fileSize = channel.size(); |  | ||||||
|  |  | ||||||
|         long threadChunk = fileSize / THREAD_COUNT; |         final TreeMap<String, MeasurementAggregator> results = StreamSupport.stream(ThreadChunk.chunk(file, THREAD_COUNT), true) | ||||||
|  |                 .map(chunk -> parseChunk(chunk, channel)) | ||||||
|         Thread[] threads = new Thread[THREAD_COUNT]; |                 .flatMap(bakm -> bakm.getAsUnorderedList().stream()) | ||||||
|         ByteArrayKeyedMap[] results = new ByteArrayKeyedMap[THREAD_COUNT]; |                 .collect(Collectors.toMap(m -> new String(m.stationNameBytes, StandardCharsets.UTF_8), m -> m, MeasurementAggregator::merge, TreeMap::new)); | ||||||
|         for (int i = 0; i < THREAD_COUNT; i++) { |         System.out.println(results); | ||||||
|             final int j = i; |  | ||||||
|             long startPoint = j * threadChunk; |  | ||||||
|             long endPoint = startPoint + threadChunk; |  | ||||||
|             Thread thread = new Thread(() -> { |  | ||||||
|                 try { |  | ||||||
|                     results[j] = readAndParse(channel, startPoint, endPoint, fileSize); |  | ||||||
|                 } |  | ||||||
|                 catch (Throwable t) { |  | ||||||
|                     System.err.println("It's broken :("); |  | ||||||
|                     // noinspection CallToPrintStackTrace |  | ||||||
|                     t.printStackTrace(); |  | ||||||
|                 } |  | ||||||
|             }); |  | ||||||
|             threads[i] = thread; |  | ||||||
|             thread.start(); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         final Map<String, MeasurementAggregator> finalAggregator = new TreeMap<>(); |  | ||||||
|  |  | ||||||
|         for (int i = 0; i < THREAD_COUNT; i++) { |  | ||||||
|             try { |  | ||||||
|                 threads[i].join(); |  | ||||||
|             } |  | ||||||
|             catch (InterruptedException e) { |  | ||||||
|                 throw new RuntimeException(e); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             results[i].getAsUnorderedList().forEach(v -> { |  | ||||||
|                 String stationName = new String(v.stationNameBytes, StandardCharsets.UTF_8); |  | ||||||
|                 finalAggregator.compute(stationName, (_, x) -> { |  | ||||||
|                     if (x == null) { |  | ||||||
|                         return v; |  | ||||||
|                     } |  | ||||||
|                     else { |  | ||||||
|                         x.count += v.count; |  | ||||||
|                         x.min = Math.min(x.min, v.min); |  | ||||||
|                         x.max = Math.max(x.max, v.max); |  | ||||||
|                         x.sum += v.sum; |  | ||||||
|                         return x; |  | ||||||
|                     } |  | ||||||
|                 }); |  | ||||||
|             }); |  | ||||||
|         } |  | ||||||
|         System.out.println(finalAggregator); |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     private static ByteArrayKeyedMap readAndParse(final FileChannel channel, |     private record ThreadChunk(long startPoint, long endPoint, long size) { | ||||||
|                                                   final long startPoint, |         public static Spliterator<CalculateAverage_palmr.ThreadChunk> chunk(final RandomAccessFile file, final int chunkCount) throws IOException { | ||||||
|                                                   final long endPoint, |             final var fileSize = file.length(); | ||||||
|                                                   final long fileSize) { |             final var idealChunkSize = fileSize / THREAD_COUNT; | ||||||
|         final State state = new State(); |             final var chunks = new CalculateAverage_palmr.ThreadChunk[chunkCount]; | ||||||
|  |  | ||||||
|         boolean skipFirstEntry = startPoint != 0; |             var startPoint = 0L; | ||||||
|  |             for (int i = 0; i < chunkCount; i++) { | ||||||
|  |                 var endPoint = Math.min(startPoint + idealChunkSize, fileSize); | ||||||
|  |                 file.seek(endPoint); | ||||||
|  |                 while (endPoint < fileSize && file.readByte() != NEWLINE_CHAR) { | ||||||
|  |                     endPoint++; | ||||||
|  |                 } | ||||||
|  |                 final var actualSize = endPoint - startPoint; | ||||||
|  |                 chunks[i] = new CalculateAverage_palmr.ThreadChunk(startPoint, endPoint, actualSize); | ||||||
|  |                 startPoint += actualSize; | ||||||
|  |             } | ||||||
|  |  | ||||||
|         long offset = startPoint; |             return Spliterators.spliterator(chunks, | ||||||
|         while (offset < endPoint) { |                     Spliterator.ORDERED | | ||||||
|             parseData(channel, state, offset, Math.min(CHUNK_SIZE, fileSize - offset), false, skipFirstEntry); |                             Spliterator.DISTINCT | | ||||||
|             skipFirstEntry = false; |                             Spliterator.SORTED | | ||||||
|             offset += CHUNK_SIZE; |                             Spliterator.NONNULL | | ||||||
|  |                             Spliterator.IMMUTABLE | | ||||||
|  |                             Spliterator.CONCURRENT | ||||||
|  |             ); | ||||||
|         } |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|         if (offset < fileSize) { |     private static ByteArrayKeyedMap parseChunk(ThreadChunk chunk, FileChannel channel) { | ||||||
|             // Make sure we finish reading any partially read entry by going a little in to the next chunk, stopping at the first newline |         final var state = new State(); | ||||||
|             parseData(channel, state, offset, Math.min(LITTLE_CHUNK_SIZE, fileSize - offset), true, false); |  | ||||||
|  |         var offset = chunk.startPoint; | ||||||
|  |         while (offset < chunk.endPoint) { | ||||||
|  |             parseData(channel, state, offset, Math.min(CHUNK_SIZE, chunk.endPoint - offset)); | ||||||
|  |             offset += CHUNK_SIZE; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         return state.aggregators; |         return state.aggregators; | ||||||
| @@ -114,69 +91,48 @@ public class CalculateAverage_palmr { | |||||||
|     private static void parseData(final FileChannel channel, |     private static void parseData(final FileChannel channel, | ||||||
|                                   final State state, |                                   final State state, | ||||||
|                                   final long offset, |                                   final long offset, | ||||||
|                                   final long bufferSize, |                                   final long bufferSize) { | ||||||
|                                   final boolean stopAtNewline, |         final ByteBuffer byteBuffer; | ||||||
|                                   final boolean skipFirstEntry) { |  | ||||||
|         ByteBuffer byteBuffer; |  | ||||||
|         try { |         try { | ||||||
|             byteBuffer = channel.map(FileChannel.MapMode.READ_ONLY, offset, bufferSize); |             byteBuffer = channel.map(FileChannel.MapMode.READ_ONLY, offset, bufferSize); | ||||||
|         } |  | ||||||
|         catch (IOException e) { |             while (byteBuffer.hasRemaining()) { | ||||||
|  |                 final var currentChar = byteBuffer.get(); | ||||||
|  |  | ||||||
|  |                 if (currentChar == SEPARATOR_CHAR) { | ||||||
|  |                     state.parsingValue = true; | ||||||
|  |                 } else if (currentChar == NEWLINE_CHAR) { | ||||||
|  |                     if (state.stationPointerEnd != 0) { | ||||||
|  |                         final var value = state.measurementValue * state.exponent; | ||||||
|  |  | ||||||
|  |                         MeasurementAggregator aggregator = state.aggregators.computeIfAbsent(state.stationBuffer, state.stationPointerEnd, state.signedHashCode); | ||||||
|  |                         aggregator.count++; | ||||||
|  |                         aggregator.min = Math.min(aggregator.min, value); | ||||||
|  |                         aggregator.max = Math.max(aggregator.max, value); | ||||||
|  |                         aggregator.sum += value; | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     // reset | ||||||
|  |                     state.reset(); | ||||||
|  |                 } else { | ||||||
|  |                     if (!state.parsingValue) { | ||||||
|  |                         state.stationBuffer[state.stationPointerEnd++] = currentChar; | ||||||
|  |                         state.signedHashCode = 31 * state.signedHashCode + (currentChar & 0xff); | ||||||
|  |                     } else { | ||||||
|  |                         if (currentChar == MINUS_CHAR) { | ||||||
|  |                             state.exponent = -0.1; | ||||||
|  |                         } else if (currentChar != DECIMAL_POINT_CHAR) { | ||||||
|  |                             state.measurementValue = state.measurementValue * 10 + (currentChar - '0'); | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } catch (IOException e) { | ||||||
|             throw new RuntimeException(e); |             throw new RuntimeException(e); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         boolean isSkippingToFirstCleanEntry = skipFirstEntry; |  | ||||||
|  |  | ||||||
|         while (byteBuffer.hasRemaining()) { |  | ||||||
|             byte currentChar = byteBuffer.get(); |  | ||||||
|  |  | ||||||
|             if (isSkippingToFirstCleanEntry) { |  | ||||||
|                 if (currentChar == '\n') { |  | ||||||
|                     isSkippingToFirstCleanEntry = false; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             if (currentChar == ';') { |  | ||||||
|                 state.parsingValue = true; |  | ||||||
|             } |  | ||||||
|             else if (currentChar == '\n') { |  | ||||||
|                 if (state.stationPointerEnd != 0) { |  | ||||||
|                     double value = state.measurementValue * state.exponent; |  | ||||||
|  |  | ||||||
|                     MeasurementAggregator aggregator = state.aggregators.computeIfAbsent(state.stationBuffer, state.stationPointerEnd, state.signedHashCode); |  | ||||||
|                     aggregator.count++; |  | ||||||
|                     aggregator.min = Math.min(aggregator.min, value); |  | ||||||
|                     aggregator.max = Math.max(aggregator.max, value); |  | ||||||
|                     aggregator.sum += value; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 if (stopAtNewline) { |  | ||||||
|                     return; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 // reset |  | ||||||
|                 state.reset(); |  | ||||||
|             } |  | ||||||
|             else { |  | ||||||
|                 if (!state.parsingValue) { |  | ||||||
|                     state.stationBuffer[state.stationPointerEnd++] = currentChar; |  | ||||||
|                     state.signedHashCode = 31 * state.signedHashCode + (currentChar & 0xff); |  | ||||||
|                 } |  | ||||||
|                 else { |  | ||||||
|                     if (currentChar == '-') { |  | ||||||
|                         state.exponent = -0.1; |  | ||||||
|                     } |  | ||||||
|                     else if (currentChar != '.') { |  | ||||||
|                         state.measurementValue = state.measurementValue * 10 + (currentChar - '0'); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     static final class State { |     private static final class State { | ||||||
|         ByteArrayKeyedMap aggregators = new ByteArrayKeyedMap(); |         ByteArrayKeyedMap aggregators = new ByteArrayKeyedMap(); | ||||||
|         boolean parsingValue = false; |         boolean parsingValue = false; | ||||||
|         byte[] stationBuffer = new byte[STATION_NAME_BUFFER_SIZE]; |         byte[] stationBuffer = new byte[STATION_NAME_BUFFER_SIZE]; | ||||||
| @@ -208,37 +164,51 @@ public class CalculateAverage_palmr { | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         public String toString() { |         public String toString() { | ||||||
|             return round(min) + "/" + round(sum / count) + "/" + round(max); |             return STR."\{round(min)}/\{round(sum / count)}/\{round(max)}"; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         private double round(double value) { |         private double round(final double value) { | ||||||
|             return Math.round(value * 10.0) / 10.0; |             return Math.round(value * 10.0) / 10.0; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         private MeasurementAggregator merge(final MeasurementAggregator b) { | ||||||
|  |             this.count += b.count; | ||||||
|  |             this.min = Math.min(this.min, b.min); | ||||||
|  |             this.max = Math.max(this.max, b.max); | ||||||
|  |             this.sum += b.sum; | ||||||
|  |             return this; | ||||||
|  |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /** | ||||||
|  |      * Very basic hash table implementation, only implementing computeIfAbsent since that's all the code needs. | ||||||
|  |      * It's sized to give minimal collisions with the example test set. this may not hold true if the stations list | ||||||
|  |      * changes, but it should still perform fairly well. | ||||||
|  |      * It uses Open Addressing, meaning it's just one array, rather Separate Chaining which is what the default java HashMap uses. | ||||||
|  |      * IT also uses Linear probing for collision resolution, which given the minimal collision count should hold up well. | ||||||
|  |      */ | ||||||
|     private static class ByteArrayKeyedMap { |     private static class ByteArrayKeyedMap { | ||||||
|         private final int BUCKET_COUNT = 0xFFF; // 413 unique stations in the data set, & 0xFFF ~= 399 (only 14 collisions (given our hashcode implementation)) |         private final int BUCKET_COUNT = 0xFFF; // 413 unique stations in the data set, & 0xFFF ~= 399 (only 14 collisions (given our hashcode implementation)) | ||||||
|         private final MeasurementAggregator[] buckets = new MeasurementAggregator[BUCKET_COUNT + 1]; |         private final MeasurementAggregator[] buckets = new MeasurementAggregator[BUCKET_COUNT + 1]; | ||||||
|         private final List<MeasurementAggregator> compactUnorderedBuckets = new ArrayList<>(413); |         private final List<MeasurementAggregator> compactUnorderedBuckets = new ArrayList<>(413); | ||||||
|  |  | ||||||
|         public MeasurementAggregator computeIfAbsent(final byte[] key, final int keyLength, final int keyHashCode) { |         public MeasurementAggregator computeIfAbsent(final byte[] key, final int keyLength, final int keyHashCode) { | ||||||
|             int index = keyHashCode & BUCKET_COUNT; |             var index = keyHashCode & BUCKET_COUNT; | ||||||
|  |  | ||||||
|             while (true) { |             while (true) { | ||||||
|                 MeasurementAggregator maybe = buckets[index]; |                 MeasurementAggregator maybe = buckets[index]; | ||||||
|                 if (maybe == null) { |                 if (maybe != null) { | ||||||
|                     final byte[] copiedKey = Arrays.copyOf(key, keyLength); |  | ||||||
|                     MeasurementAggregator measurementAggregator = new MeasurementAggregator(copiedKey, keyHashCode); |  | ||||||
|                     buckets[index] = measurementAggregator; |  | ||||||
|                     compactUnorderedBuckets.add(measurementAggregator); |  | ||||||
|                     return measurementAggregator; |  | ||||||
|                 } |  | ||||||
|                 else { |  | ||||||
|                     if (Arrays.equals(key, 0, keyLength, maybe.stationNameBytes, 0, maybe.stationNameBytes.length)) { |                     if (Arrays.equals(key, 0, keyLength, maybe.stationNameBytes, 0, maybe.stationNameBytes.length)) { | ||||||
|                         return maybe; |                         return maybe; | ||||||
|                     } |                     } | ||||||
|                     index++; |                     index++; | ||||||
|                     index &= BUCKET_COUNT; |                     index &= BUCKET_COUNT; | ||||||
|  |                 } else { | ||||||
|  |                     final var copiedKey = Arrays.copyOf(key, keyLength); | ||||||
|  |                     MeasurementAggregator measurementAggregator = new MeasurementAggregator(copiedKey, keyHashCode); | ||||||
|  |                     buckets[index] = measurementAggregator; | ||||||
|  |                     compactUnorderedBuckets.add(measurementAggregator); | ||||||
|  |                     return measurementAggregator; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user