Sixth attempt CalculateAverage_zerninv.java (#407)
* rethink chunking
* fix typo
This commit is contained in:
		| @@ -25,14 +25,15 @@ import java.nio.channels.FileChannel; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.nio.file.Path; | ||||
| import java.nio.file.StandardOpenOption; | ||||
| import java.util.*; | ||||
| import java.util.concurrent.ExecutionException; | ||||
| import java.util.concurrent.Executors; | ||||
| import java.util.concurrent.Future; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| import java.util.TreeMap; | ||||
|  | ||||
| public class CalculateAverage_zerninv { | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|     private static final int MIN_FILE_SIZE = 1024 * 1024 * 16; | ||||
|     private static final int L3_CACHE_SIZE = 128 * 1024 * 1024; | ||||
|     private static final int CORES = Runtime.getRuntime().availableProcessors(); | ||||
|     private static final int CHUNK_SIZE = (L3_CACHE_SIZE - MeasurementContainer.SIZE * MeasurementContainer.ENTRY_SIZE * CORES) / CORES - 1024 * CORES; | ||||
|  | ||||
|     // #.## | ||||
|     private static final int THREE_DIGITS_MASK = 0x2e0000; | ||||
| @@ -48,47 +49,48 @@ public class CalculateAverage_zerninv { | ||||
|  | ||||
|     private static final Unsafe UNSAFE = initUnsafe(); | ||||
|  | ||||
|     public static void main(String[] args) throws IOException { | ||||
|         var results = new HashMap<String, MeasurementAggregation>(); | ||||
|     public static void main(String[] args) throws IOException, InterruptedException { | ||||
|         try (var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { | ||||
|             var fileSize = channel.size(); | ||||
|             var memorySegment = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()); | ||||
|             long address = memorySegment.address(); | ||||
|             var cores = Runtime.getRuntime().availableProcessors(); | ||||
|             var minChunkSize = fileSize < MIN_FILE_SIZE ? fileSize : fileSize / cores; | ||||
|             var chunks = splitByChunks(address, address + fileSize, minChunkSize); | ||||
|             var minChunkSize = Math.min(fileSize, CHUNK_SIZE); | ||||
|  | ||||
|             var executor = Executors.newFixedThreadPool(cores); | ||||
|             List<Future<Map<String, MeasurementAggregation>>> fResults = new ArrayList<>(); | ||||
|             for (int i = 1; i < chunks.size(); i++) { | ||||
|                 final long prev = chunks.get(i - 1); | ||||
|                 final long curr = chunks.get(i); | ||||
|                 fResults.add(executor.submit(() -> calcForChunk(prev, curr))); | ||||
|             var tasks = new TaskThread[CORES]; | ||||
|             for (int i = 0; i < tasks.length; i++) { | ||||
|                 tasks[i] = new TaskThread(new MeasurementContainer(), (int) (fileSize / minChunkSize / CORES + 1)); | ||||
|             } | ||||
|  | ||||
|             fResults.forEach(f -> { | ||||
|                 try { | ||||
|                     f.get().forEach((key, value) -> { | ||||
|                         var result = results.get(key); | ||||
|                         if (result != null) { | ||||
|                             result.merge(value); | ||||
|                         } | ||||
|                         else { | ||||
|                             results.put(key, value); | ||||
|                         } | ||||
|                     }); | ||||
|                 } | ||||
|                 catch (InterruptedException | ExecutionException e) { | ||||
|                     e.printStackTrace(); | ||||
|                 } | ||||
|             }); | ||||
|             executor.shutdown(); | ||||
|         } | ||||
|             var memorySegment = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()); | ||||
|             var address = memorySegment.address(); | ||||
|             var chunks = splitByChunks(address, address + fileSize, minChunkSize); | ||||
|             for (int i = 0; i < chunks.size() - 1; i++) { | ||||
|                 var task = tasks[i % CORES]; | ||||
|                 task.addChunk(chunks.get(i), chunks.get(i + 1)); | ||||
|             } | ||||
|  | ||||
|         var bos = new BufferedOutputStream(System.out); | ||||
|         bos.write(new TreeMap<>(results).toString().getBytes(StandardCharsets.UTF_8)); | ||||
|         bos.write('\n'); | ||||
|         bos.flush(); | ||||
|             for (var task : tasks) { | ||||
|                 task.start(); | ||||
|             } | ||||
|  | ||||
|             var results = new TreeMap<String, TemperatureAggregation>(); | ||||
|             for (var task : tasks) { | ||||
|                 task.join(); | ||||
|                 task.measurements() | ||||
|                         .forEach(measurement -> { | ||||
|                             var aggr = results.get(measurement.station()); | ||||
|                             if (aggr == null) { | ||||
|                                 results.put(measurement.station(), measurement.aggregation()); | ||||
|                             } | ||||
|                             else { | ||||
|                                 aggr.merge(measurement.aggregation()); | ||||
|                             } | ||||
|                         }); | ||||
|             } | ||||
|  | ||||
|             var bos = new BufferedOutputStream(System.out); | ||||
|             bos.write(new TreeMap<>(results).toString().getBytes(StandardCharsets.UTF_8)); | ||||
|             bos.write('\n'); | ||||
|             bos.flush(); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static Unsafe initUnsafe() { | ||||
| @@ -103,7 +105,7 @@ public class CalculateAverage_zerninv { | ||||
|     } | ||||
|  | ||||
|     private static List<Long> splitByChunks(long address, long end, long minChunkSize) { | ||||
|         List<Long> result = new ArrayList<>(); | ||||
|         List<Long> result = new ArrayList<>((int) ((end - address) / minChunkSize + 1)); | ||||
|         result.add(address); | ||||
|         while (address < end) { | ||||
|             address += Math.min(end - address, minChunkSize); | ||||
| @@ -114,60 +116,20 @@ public class CalculateAverage_zerninv { | ||||
|         return result; | ||||
|     } | ||||
|  | ||||
|     private static Map<String, MeasurementAggregation> calcForChunk(long offset, long end) { | ||||
          // Parses the records stored in the raw address range [offset, end) and returns the
          // per-station aggregates as a map built by MeasurementContainer.toStringMap().
          // A fresh MeasurementContainer is allocated on every call.
          // NOTE(review): this listing is a diff rendering without +/- markers; this static
          // variant (returning Map<String, MeasurementAggregation>) looks like the removed
          // side of the change - confirm before relying on it.
|         var results = new MeasurementContainer(); | ||||
|  | ||||
|         long cityOffset; | ||||
|         int hashCode, temperature, word; | ||||
|         byte cityNameSize, b; | ||||
|  | ||||
|         while (offset < end) { | ||||
              // Remember where the name starts, then scan up to DELIMITER while hashing the
              // name bytes as h = h * 31 + b (b is a signed byte).
|             cityOffset = offset; | ||||
|             hashCode = 0; | ||||
|             while ((b = UNSAFE.getByte(offset++)) != DELIMITER) { | ||||
|                 hashCode = hashCode * 31 + b; | ||||
|             } | ||||
              // offset is now one past the delimiter, hence the extra -1. The length is narrowed
              // to byte, so a name longer than 127 bytes would wrap.
|             cityNameSize = (byte) (offset - cityOffset - 1); | ||||
|  | ||||
              // Load the next four bytes as one int and tentatively consume all four; the
              // branches below adjust offset when the number is shorter or longer.
              // Presumably assumes little-endian layout (low bits = first byte) - TODO confirm.
|             word = UNSAFE.getInt(offset); | ||||
|             offset += 4; | ||||
|  | ||||
              // The masks, ZERO and BYTE_MASK are defined outside this view, except
              // THREE_DIGITS_MASK (0x2e0000, i.e. the bits of '.' in bits 16-23). ZERO is
              // presumably the character code of '0'; multiplying it by 11 or 111 removes that
              // bias from every digit term at once. The decimal point is skipped, so the result
              // is the reading with its fractional digit kept as the units digit.
|             if ((word & TWO_NEGATIVE_DIGITS_MASK) == TWO_NEGATIVE_DIGITS_MASK) { | ||||
                  // Drop the lowest byte, then combine two digits into a negated value.
|                 word >>>= 8; | ||||
|                 temperature = ZERO * 11 - ((word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK)); | ||||
|             } | ||||
|             else if ((word & THREE_DIGITS_MASK) == THREE_DIGITS_MASK) { | ||||
                  // Three digits taken from bits 0-7, 8-15 and 24-31; bits 16-23 are skipped.
|                 temperature = (word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK) - ZERO * 111; | ||||
|             } | ||||
|             else if ((word & TWO_DIGITS_MASK) == TWO_DIGITS_MASK) { | ||||
                  // Two digits from bits 0-7 and 16-23; only three bytes belong to the number,
                  // so give one byte back.
|                 temperature = (word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK) - ZERO * 11; | ||||
|                 offset--; | ||||
|             } | ||||
|             else { | ||||
|                 // #.##- | ||||
                  // Fallback: drop the lowest byte, pull one extra byte into the top eight bits
                  // (consuming it), then combine three digits into a negated value.
|                 word = (word >>> 8) | (UNSAFE.getByte(offset++) << 24); | ||||
|                 temperature = ZERO * 111 - ((word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK)); | ||||
|             } | ||||
              // Skip one more byte after the number (presumably the line terminator - TODO confirm).
|             offset++; | ||||
              // The name is passed by address and length, not copied; temperature is narrowed to short.
|             results.put(cityOffset, cityNameSize, hashCode, (short) temperature); | ||||
|         } | ||||
|         return results.toStringMap(); | ||||
|     } | ||||
|  | ||||
|     private static final class MeasurementAggregation { | ||||
|     private static final class TemperatureAggregation { | ||||
|         private long sum; | ||||
|         private int count; | ||||
|         private short min; | ||||
|         private short max; | ||||
|  | ||||
|         public MeasurementAggregation(long sum, int count, short min, short max) { | ||||
|         public TemperatureAggregation(long sum, int count, short min, short max) { | ||||
|             this.sum = sum; | ||||
|             this.count = count; | ||||
|             this.min = min; | ||||
|             this.max = max; | ||||
|         } | ||||
|  | ||||
|         public void merge(MeasurementAggregation o) { | ||||
|         public void merge(TemperatureAggregation o) { | ||||
|             if (o == null) { | ||||
|                 return; | ||||
|             } | ||||
| @@ -183,6 +145,9 @@ public class CalculateAverage_zerninv { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private record Measurement(String station, TemperatureAggregation aggregation) { | ||||
          // Pairs a station name with its aggregate. In this listing it is created by
          // MeasurementContainer.measurements() and read back in main, which merges the
          // aggregations of equal station names into a TreeMap.
|     } | ||||
|  | ||||
|     private static final class MeasurementContainer { | ||||
|         private static final int SIZE = 1024 * 16; | ||||
|  | ||||
| @@ -235,26 +200,26 @@ public class CalculateAverage_zerninv { | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public Map<String, MeasurementAggregation> toStringMap() { | ||||
|             var result = new HashMap<String, MeasurementAggregation>(); | ||||
|         public List<Measurement> measurements() { | ||||
|             var result = new ArrayList<Measurement>(1000); | ||||
|             int count; | ||||
|             for (int i = 0; i < SIZE; i++) { | ||||
|                 long ptr = this.address + i * ENTRY_SIZE; | ||||
|                 count = UNSAFE.getInt(ptr + COUNT_OFFSET); | ||||
|                 if (count != 0) { | ||||
|                     var measurements = new MeasurementAggregation( | ||||
|                     var measurements = new TemperatureAggregation( | ||||
|                             UNSAFE.getLong(ptr + SUM_OFFSET), | ||||
|                             count, | ||||
|                             UNSAFE.getShort(ptr + MIN_OFFSET), | ||||
|                             UNSAFE.getShort(ptr + MAX_OFFSET)); | ||||
|                     var key = createString(UNSAFE.getLong(ptr + ADDRESS_OFFSET), UNSAFE.getByte(ptr + SIZE_OFFSET)); | ||||
|                     result.put(key, measurements); | ||||
|                     result.add(new Measurement(key, measurements)); | ||||
|                 } | ||||
|             } | ||||
|             return result; | ||||
|         } | ||||
|  | ||||
|         private boolean isEqual(long address, long address2, byte size) { | ||||
|         private static boolean isEqual(long address, long address2, byte size) { | ||||
|             for (int i = 0; i < size; i++) { | ||||
|                 if (UNSAFE.getByte(address + i) != UNSAFE.getByte(address2 + i)) { | ||||
|                     return false; | ||||
| @@ -271,4 +236,69 @@ public class CalculateAverage_zerninv { | ||||
|             return new String(arr); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
|     private static class TaskThread extends Thread { | ||||
          // Worker thread that parses a list of raw address ranges into its own
          // MeasurementContainer. Ranges are kept as two parallel lists: begins.get(i) pairs
          // with ends.get(i). Both are plain ArrayLists with no synchronization; in the main
          // method visible in this listing every addChunk call happens before start(), and
          // measurements() is read only after join().
|         private final MeasurementContainer container; | ||||
|         private final List<Long> begins; | ||||
|         private final List<Long> ends; | ||||
|  | ||||
          // chunks is used only as the initial capacity of both lists, not as a limit.
|         private TaskThread(MeasurementContainer container, int chunks) { | ||||
|             this.container = container; | ||||
|             this.begins = new ArrayList<>(chunks); | ||||
|             this.ends = new ArrayList<>(chunks); | ||||
|         } | ||||
|  | ||||
          // Registers one range for run(); begin and end are appended at the same index.
          // calcForChunk scans the range while offset < end.
|         public void addChunk(long begin, long end) { | ||||
|             begins.add(begin); | ||||
|             ends.add(end); | ||||
|         } | ||||
|  | ||||
          // Processes the registered ranges one after another, in the order they were added.
|         @Override | ||||
|         public void run() { | ||||
|             for (int i = 0; i < begins.size(); i++) { | ||||
|                 calcForChunk(begins.get(i), ends.get(i)); | ||||
|             } | ||||
|         } | ||||
|  | ||||
          // Delegates to container.measurements(); nothing is cached here.
|         public List<Measurement> measurements() { | ||||
|             return container.measurements(); | ||||
|         } | ||||
|  | ||||
          // Parses the records in [offset, end) and feeds each one into this thread's container.
|         private void calcForChunk(long offset, long end) { | ||||
|             long cityOffset; | ||||
|             int hashCode, temperature, word; | ||||
|             byte cityNameSize, b; | ||||
|  | ||||
|             while (offset < end) { | ||||
                  // Remember where the name starts, then scan up to DELIMITER while hashing the
                  // name bytes as h = h * 31 + b (b is a signed byte).
|                 cityOffset = offset; | ||||
|                 hashCode = 0; | ||||
|                 while ((b = UNSAFE.getByte(offset++)) != DELIMITER) { | ||||
|                     hashCode = hashCode * 31 + b; | ||||
|                 } | ||||
                  // offset is now one past the delimiter, hence the extra -1. The length is
                  // narrowed to byte, so a name longer than 127 bytes would wrap.
|                 cityNameSize = (byte) (offset - cityOffset - 1); | ||||
|  | ||||
                  // Load the next four bytes as one int and tentatively consume all four; the
                  // branches below adjust offset when the number is shorter or longer.
                  // Presumably assumes little-endian layout (low bits = first byte) - TODO confirm.
|                 word = UNSAFE.getInt(offset); | ||||
|                 offset += 4; | ||||
|  | ||||
                  // The masks, ZERO and BYTE_MASK are defined outside this view, except
                  // THREE_DIGITS_MASK (0x2e0000, i.e. the bits of '.' in bits 16-23). ZERO is
                  // presumably the character code of '0'; multiplying it by 11 or 111 removes
                  // that bias from every digit term at once. The decimal point is skipped, so
                  // the result is the reading with its fractional digit kept as the units digit.
|                 if ((word & TWO_NEGATIVE_DIGITS_MASK) == TWO_NEGATIVE_DIGITS_MASK) { | ||||
                      // Drop the lowest byte, then combine two digits into a negated value.
|                     word >>>= 8; | ||||
|                     temperature = ZERO * 11 - ((word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK)); | ||||
|                 } | ||||
|                 else if ((word & THREE_DIGITS_MASK) == THREE_DIGITS_MASK) { | ||||
                      // Three digits taken from bits 0-7, 8-15 and 24-31; bits 16-23 are skipped.
|                     temperature = (word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK) - ZERO * 111; | ||||
|                 } | ||||
|                 else if ((word & TWO_DIGITS_MASK) == TWO_DIGITS_MASK) { | ||||
                      // Two digits from bits 0-7 and 16-23; only three bytes belong to the
                      // number, so give one byte back.
|                     temperature = (word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK) - ZERO * 11; | ||||
|                     offset--; | ||||
|                 } | ||||
|                 else { | ||||
|                     // #.##- | ||||
                      // Fallback: drop the lowest byte, pull one extra byte into the top eight
                      // bits (consuming it), then combine three digits into a negated value.
|                     word = (word >>> 8) | (UNSAFE.getByte(offset++) << 24); | ||||
|                     temperature = ZERO * 111 - ((word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK)); | ||||
|                 } | ||||
                  // Skip one more byte after the number (presumably the line terminator - TODO confirm).
|                 offset++; | ||||
                  // The name is passed by address and length, not copied; temperature is narrowed to short.
|                 container.put(cityOffset, cityNameSize, hashCode, (short) temperature); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user