isolgpus: submission 1
* isolgpus: submission 1 * isolgpus: fix min value bug (breaks if a negative temperature never appears) * isolgpus: remove unused collector * isolgpus: fix split on chunk bug * isolgpus: change name equality algo to a cheaper check. * isolgpus: fix chunking state to cope with last byte of last chunk * isolgpus: hash as we go, instead of at the end * isolgpus: adjust thread count to core count * isolgpus: change cores to 8 statically --------- Co-authored-by: Jamie Stansfield <jalstansfield@gmail.com>
This commit is contained in:
		
							
								
								
									
										20
									
								
								calculate_average_isolgpus.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										20
									
								
								calculate_average_isolgpus.sh
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| #!/bin/sh | ||||
| # | ||||
| #  Copyright 2023 The original authors | ||||
| # | ||||
| #  Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| #  you may not use this file except in compliance with the License. | ||||
| #  You may obtain a copy of the License at | ||||
| # | ||||
| #      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| #  Unless required by applicable law or agreed to in writing, software | ||||
| #  distributed under the License is distributed on an "AS IS" BASIS, | ||||
| #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| #  See the License for the specific language governing permissions and | ||||
| #  limitations under the License. | ||||
| # | ||||
|  | ||||
|  | ||||
# JVM flags for the run; kept in a variable so they are easy to adjust in one place.
JAVA_OPTS='--enable-preview'

# $JAVA_OPTS is intentionally unquoted so that several flags split into separate arguments.
# `time` reports how long the whole JVM run takes.
time java $JAVA_OPTS \
    --class-path target/average-1.0.0-SNAPSHOT.jar \
    dev.morling.onebrc.CalculateAverage_isolgpus
							
								
								
									
										293
									
								
								src/main/java/dev/morling/onebrc/CalculateAverage_isolgpus.java
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										293
									
								
								src/main/java/dev/morling/onebrc/CalculateAverage_isolgpus.java
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,293 @@ | ||||
| /* | ||||
|  *  Copyright 2023 The original authors | ||||
|  * | ||||
|  *  Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  *  you may not use this file except in compliance with the License. | ||||
|  *  You may obtain a copy of the License at | ||||
|  * | ||||
|  *      http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  *  Unless required by applicable law or agreed to in writing, software | ||||
|  *  distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  *  See the License for the specific language governing permissions and | ||||
|  *  limitations under the License. | ||||
|  */ | ||||
| package dev.morling.onebrc; | ||||
|  | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.io.RandomAccessFile; | ||||
| import java.math.BigDecimal; | ||||
| import java.math.RoundingMode; | ||||
| import java.nio.BufferUnderflowException; | ||||
| import java.nio.MappedByteBuffer; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.file.Paths; | ||||
| import java.util.*; | ||||
| import java.util.concurrent.*; | ||||
| import java.util.stream.Collectors; | ||||
|  | ||||
/**
 * Reads {@code ./measurements.txt}, whose lines have the form {@code <name>;<value>} where the
 * value is an optionally negative decimal number with exactly one fractional digit, and prints
 * {@code {name=min/mean/max, ...}} with the names in sorted order.
 *
 * <p>The file is split into byte ranges ("chunks"). Each chunk is memory-mapped and parsed by a
 * worker task into its own hash table, and the per-chunk tables are merged on the main thread
 * once every task has finished. Values are accumulated as integer tenths (e.g. {@code -24.9}
 * is stored as {@code -249}).
 */
public class CalculateAverage_isolgpus {

    /** Number of buckets in each per-chunk hash table; a power of two so a mask can select the bucket. */
    public static final int HISTOGRAMS_LENGTH = 1024 * 32;
    /** Bit mask that maps a name hash onto a bucket index. */
    public static final int HISTOGRAMS_MASK = HISTOGRAMS_LENGTH - 1;
    /** Size of the worker pool and the minimum number of chunks the file is split into. */
    public static final int THREAD_COUNT = 8;
    private static final String FILE = "./measurements.txt";
    /** Byte separating the name from the value. (Spelling kept: the constant is public.) */
    public static final byte SEPERATOR = ';';
    /** ASCII code of the digit zero, subtracted from a digit byte to get its numeric value. */
    public static final byte OFFSET = '0';
    public static final byte NEGATIVE = '-';
    public static final byte DECIMAL_POINT = '.';
    public static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 100; // bit of wiggle room
    public static final byte NEW_LINE = '\n';

    /** Extra bytes mapped past a chunk's nominal end so the last line it owns can be read in full. */
    private static final int LINE_SLACK = 200;
    /** Capacity of the scratch buffer that holds the name of the line being parsed. */
    private static final int MAX_NAME_LENGTH = 100;

    /**
     * Splits the input file into chunks, parses them on a fixed thread pool, merges the partial
     * results and prints the final report to standard output.
     *
     * @param args ignored
     * @throws IOException if the input file cannot be opened or mapped
     * @throws InterruptedException if the main thread is interrupted while waiting for a chunk
     * @throws ExecutionException if parsing a chunk fails
     */
    public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
        File file = Paths.get(FILE).toFile();
        long length = file.length();

        // Reserve LINE_SLACK inside the per-chunk budget so that "chunk + slack" never exceeds the
        // maximum size FileChannel.map accepts (Integer.MAX_VALUE).
        long usableChunkSize = MAX_CHUNK_SIZE - LINE_SLACK;
        long chunksCount = Math.max(THREAD_COUNT, (length + usableChunkSize - 1) / usableChunkSize);

        // Spread the division remainder over the first chunks so that every byte of the file
        // belongs to exactly one chunk (the tail must not be left unowned).
        long baseChunkSize = length / chunksCount;
        long oversizedChunks = length % chunksCount;

        List<MeasurementCollector[]> measurementCollectors = new ArrayList<>();
        ExecutorService executorService = Executors.newFixedThreadPool(THREAD_COUNT);
        try (RandomAccessFile input = new RandomAccessFile(file, "r");
                FileChannel channel = input.getChannel()) {
            List<Future<MeasurementCollector[]>> futures = new ArrayList<>();
            for (long chunk = 0; chunk < chunksCount; chunk++) {
                long start = chunk * baseChunkSize + Math.min(chunk, oversizedChunks);
                long size = baseChunkSize + (chunk < oversizedChunks ? 1 : 0);
                futures.add(executorService.submit(() -> handleChunk(channel, start, size, length)));
            }
            for (Future<MeasurementCollector[]> result : futures) {
                measurementCollectors.add(result.get());
            }
        }
        finally {
            // Always release the worker threads, even when a chunk failed.
            executorService.shutdown();
        }

        Map<String, MeasurementCollector> measurementCollectorsByCity = mergeMeasurements(measurementCollectors);
        String report = measurementCollectorsByCity.entrySet().stream()
                .map(entry -> MeasurementResult.from(entry.getKey(), entry.getValue()).toString())
                .collect(Collectors.joining(", ", "{", "}"));

        System.out.println(report);
    }

    /**
     * Combines the per-chunk hash tables into one map keyed by the station name.
     *
     * @param resultsFromAllChunk one hash table per chunk, as returned by {@code handleChunk}
     * @return merged statistics per name, sorted by the natural ordering of the names
     */
    private static Map<String, MeasurementCollector> mergeMeasurements(List<MeasurementCollector[]> resultsFromAllChunk) {
        Map<String, MeasurementCollector> mergedResults = new TreeMap<>();

        for (MeasurementCollector[] resultFromSpecificChunk : resultsFromAllChunk) {
            for (MeasurementCollector bucketHead : resultFromSpecificChunk) {
                // Walk the collision chain of this bucket.
                for (MeasurementCollector partial = bucketHead; partial != null; partial = partial.link) {
                    String city = new String(partial.name, java.nio.charset.StandardCharsets.UTF_8);
                    MeasurementCollector total = mergedResults.get(city);
                    if (total == null) {
                        total = new MeasurementCollector(partial.name, partial.hash);
                        mergedResults.put(city, total);
                    }
                    total.merge(partial);
                }
            }
        }

        return mergedResults;
    }

    /**
     * Parses every line whose first byte lies in
     * {@code [estimatedStart, estimatedStart + lengthOfChunk)}.
     *
     * <p>A line belongs to the chunk in which it starts, so the mapping begins one byte before
     * the chunk (to detect whether the chunk starts exactly on a line boundary) and extends
     * {@code LINE_SLACK} bytes past it (to finish the last owned line).
     *
     * @param channel channel of the input file
     * @param estimatedStart absolute offset of the first byte of the chunk
     * @param lengthOfChunk number of bytes in the chunk
     * @param maxLengthOfFile total length of the input file
     * @return the hash table holding the statistics of the lines owned by this chunk
     * @throws IOException if the region cannot be mapped
     */
    private static MeasurementCollector[] handleChunk(FileChannel channel, long estimatedStart, long lengthOfChunk, long maxLengthOfFile) throws IOException {
        MeasurementCollector[] measurementCollectors = new MeasurementCollector[HISTOGRAMS_LENGTH];
        if (lengthOfChunk <= 0) {
            // Happens for files shorter than the number of chunks (including an empty file).
            return measurementCollectors;
        }

        long seekStart = Math.max(estimatedStart - 1, 0);
        long mappedLength = Math.min(lengthOfChunk + LINE_SLACK, maxLengthOfFile - seekStart);
        MappedByteBuffer r = channel.map(FileChannel.MapMode.READ_ONLY, seekStart, mappedLength);

        // Buffer position at which the next chunk's territory begins.
        long ownedEnd = estimatedStart + lengthOfChunk - seekStart;

        // Unless this is the start of the file, skip to the first line that begins inside the
        // chunk. If the byte just before the chunk is a newline, only that byte is skipped.
        if (estimatedStart != 0) {
            while (r.hasRemaining() && r.get() != NEW_LINE) {
                // keep skipping the tail of the previous chunk's line
            }
        }

        byte[] nameBuffer = new byte[MAX_NAME_LENGTH];
        while (r.position() < ownedEnd && r.hasRemaining()) {
            int nameLength = 0;
            int hash = 0;
            byte aChar;
            // Hash the name while copying it, using the same polynomial as String.hashCode.
            while ((aChar = r.get()) != SEPERATOR) {
                nameBuffer[nameLength++] = aChar;
                hash = 31 * hash + aChar;
            }

            int value = readValue(r);
            resolveMeasurementCollector(measurementCollectors, hash, nameBuffer, nameLength).feed(value);
        }

        return measurementCollectors;
    }

    /**
     * Reads one value of the form {@code [-]digits.digit} followed by a line terminator, and
     * leaves the buffer positioned at the start of the next line.
     *
     * @param r buffer positioned on the first byte after the separator
     * @return the value in tenths, e.g. {@code -24.9} yields {@code -249}
     */
    private static int readValue(MappedByteBuffer r) {
        byte aChar = r.get();
        boolean isNegative = aChar == NEGATIVE;
        if (isNegative) {
            aChar = r.get();
        }

        int tenths = 0;
        while (aChar != DECIMAL_POINT) {
            tenths = tenths * 10 + (aChar - OFFSET);
            aChar = r.get();
        }
        tenths = tenths * 10 + (r.get() - OFFSET);

        // Consume the newline; the very last line of the file may legitimately lack one.
        if (r.hasRemaining()) {
            r.get();
        }
        return isNegative ? -tenths : tenths;
    }

    /**
     * Finds the collector for the given name in the chunk's hash table, creating and linking a
     * new one if the name has not been seen in this chunk yet.
     *
     * @param measurementCollectors the chunk's hash table
     * @param hash hash of the name
     * @param nameBuffer scratch buffer whose first {@code nameLength} bytes are the name
     * @param nameLength number of valid bytes in {@code nameBuffer}
     * @return the collector that accumulates values for this name
     */
    private static MeasurementCollector resolveMeasurementCollector(MeasurementCollector[] measurementCollectors, int hash, byte[] nameBuffer, int nameLength) {
        int bucket = hash & HISTOGRAMS_MASK;
        MeasurementCollector candidate = measurementCollectors[bucket];
        if (candidate == null) {
            candidate = new MeasurementCollector(Arrays.copyOf(nameBuffer, nameLength), hash);
            measurementCollectors[bucket] = candidate;
            return candidate;
        }

        // collision unhappy path: walk the chain until the name matches or the chain ends
        while (!candidate.hasName(nameBuffer, nameLength, hash)) {
            if (candidate.link == null) {
                candidate.link = new MeasurementCollector(Arrays.copyOf(nameBuffer, nameLength), hash);
                return candidate.link;
            }
            candidate = candidate.link;
        }
        return candidate;
    }

    /**
     * Running min/max/sum/count for one name, doubling as a node of a bucket's collision chain.
     * Instances are not synchronized; each one is written by a single worker task and only read
     * by the main thread after that task's {@code Future} has completed.
     */
    private static final class MeasurementCollector {
        private final byte[] name;
        private final int hash;
        /** Next collector in the same bucket, or {@code null} at the end of the chain. */
        private MeasurementCollector link;
        private long sum;
        private long count;
        private int min = Integer.MAX_VALUE;
        private int max = Integer.MIN_VALUE;

        MeasurementCollector(byte[] name, int hash) {
            this.name = name;
            this.hash = hash;
        }

        /**
         * Tells whether this collector belongs to the name held in {@code buffer}. The hash is
         * only a cheap pre-filter; the bytes themselves decide, so distinct names never share
         * a collector.
         */
        boolean hasName(byte[] buffer, int length, int candidateHash) {
            return hash == candidateHash && Arrays.equals(name, 0, name.length, buffer, 0, length);
        }

        /** Adds a single value, given in tenths. */
        void feed(int value) {
            sum += value;
            count++;
            min = Math.min(value, min);
            max = Math.max(value, max);
        }

        /** Folds the statistics of another collector into this one. */
        void merge(MeasurementCollector measurementCollector) {
            this.sum += measurementCollector.sum;
            this.count += measurementCollector.count;
            this.min = Math.min(measurementCollector.min, this.min);
            this.max = Math.max(measurementCollector.max, this.max);
        }
    }

    /** Printable min/mean/max of one name, converted from tenths back to one decimal place. */
    private record MeasurementResult(String name, double mean, BigDecimal max, BigDecimal min) {

        @Override
        public String toString() {
            // Abha=-24.9/18.0/61.7
            return name + "=" + min + "/" + mean + "/" + max;
        }

        /**
         * Builds the printable result for a merged collector.
         *
         * @param name decoded station name
         * @param mc merged statistics, in tenths
         */
        static MeasurementResult from(String name, MeasurementCollector mc) {
            // Round the mean (still in tenths) to the nearest integer, ties towards positive
            // infinity, then scale back to one decimal place.
            double mean = Math.round((double) mc.sum / (double) mc.count) / 10d;
            BigDecimal max = BigDecimal.valueOf(mc.max).divide(BigDecimal.TEN, 1, RoundingMode.HALF_UP);
            BigDecimal min = BigDecimal.valueOf(mc.min).divide(BigDecimal.TEN, 1, RoundingMode.HALF_UP);
            return new MeasurementResult(name, mean, max, min);
        }
    }
}
		Reference in New Issue
	
	Block a user