armandino: second attempt (#445)
This commit is contained in:
		| @@ -15,188 +15,143 @@ | ||||
|  */ | ||||
| package dev.morling.onebrc; | ||||
|  | ||||
| import sun.misc.Unsafe; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.io.PrintStream; | ||||
| import java.nio.ByteBuffer; | ||||
| import java.lang.foreign.Arena; | ||||
| import java.lang.reflect.Field; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.file.Path; | ||||
| import java.nio.file.StandardOpenOption; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collections; | ||||
| import java.util.Map; | ||||
| import java.util.concurrent.ConcurrentHashMap; | ||||
| import java.util.Arrays; | ||||
| import java.util.Collection; | ||||
| import java.util.Objects; | ||||
| import java.util.TreeMap; | ||||
| import java.util.stream.Stream; | ||||
|  | ||||
| import static java.nio.channels.FileChannel.MapMode.READ_ONLY; | ||||
| import static java.nio.charset.StandardCharsets.UTF_8; | ||||
| import static java.util.stream.Collectors.toMap; | ||||
|  | ||||
| public class CalculateAverage_armandino { | ||||
|  | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|     private static final Path FILE = Path.of("./measurements.txt"); | ||||
|  | ||||
|     private static final int MAX_KEY_LENGTH = 100; | ||||
|     private static final int NUM_CHUNKS = Math.max(8, Runtime.getRuntime().availableProcessors()); | ||||
|     private static final int INITIAL_MAP_CAPACITY = 8192; | ||||
|     private static final byte SEMICOLON = 59; | ||||
|     private static final byte NL = 10; | ||||
|     private static final byte DOT = 46; | ||||
|     private static final byte MINUS = 45; | ||||
|     private static final byte ZERO_DIGIT = 48; | ||||
|     private static final Unsafe UNSAFE = getUnsafe(); | ||||
|  | ||||
|     public static void main(String[] args) throws Exception { | ||||
|         Aggregator aggregator = new Aggregator(); | ||||
|         aggregator.process(); | ||||
|         aggregator.printStats(); | ||||
|         var channel = FileChannel.open(FILE, StandardOpenOption.READ); | ||||
|  | ||||
|         var results = Arrays.stream(split(channel)).parallel() | ||||
|                 .map(chunk -> new ChunkProcessor().process(chunk.start, chunk.end)) | ||||
|                 .flatMap(SimpleMap::stream) | ||||
|                 .collect(toMap(Stats::getKey, s -> s, CalculateAverage_armandino::mergeStats, TreeMap::new)); | ||||
|  | ||||
|         print(results.values()); | ||||
|     } | ||||
|  | ||||
|     private static class Aggregator { | ||||
|     /** | ||||
|      * Folds the aggregate {@code y} into {@code x} and returns {@code x}. | ||||
|      * Used as the merge function of the {@code toMap} collector in {@code main}, | ||||
|      * i.e. when the same key was produced by more than one chunk. | ||||
|      * | ||||
|      * @param x accumulator; mutated in place | ||||
|      * @param y aggregate to merge in; not modified | ||||
|      * @return {@code x}, now holding the combined min, max, sum and count | ||||
|      */ | ||||
|     private static Stats mergeStats(final Stats x, final Stats y) { | ||||
|         if (y.min < x.min) { | ||||
|             x.min = y.min; | ||||
|         } | ||||
|         if (y.max > x.max) { | ||||
|             x.max = y.max; | ||||
|         } | ||||
|         x.sum += y.sum; | ||||
|         x.count += y.count; | ||||
|         return x; | ||||
|     } | ||||
|  | ||||
|         private final Map<Integer, Stats> map = new ConcurrentHashMap<>(2048); | ||||
|     private static class ChunkProcessor { | ||||
|         private final SimpleMap map = new SimpleMap(INITIAL_MAP_CAPACITY); | ||||
|  | ||||
|         private record Chunk(long start, long end) { | ||||
|         } | ||||
|         private SimpleMap process(final long chunkStart, final long chunkEnd) { | ||||
|             long i = chunkStart; | ||||
|             while (i < chunkEnd) { | ||||
|                 final long keyAddress = i; | ||||
|                 int keyHash = 0; | ||||
|                 int measurement = 0; | ||||
|                 byte b; | ||||
|  | ||||
|         void process() throws Exception { | ||||
|             var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ); | ||||
|             final Chunk[] chunks = split(channel); | ||||
|             final Thread[] threads = new Thread[chunks.length]; | ||||
|                 while ((b = UNSAFE.getByte(i++)) != SEMICOLON) { | ||||
|                     keyHash = 31 * keyHash + b; | ||||
|                 } | ||||
|  | ||||
|             for (int i = 0; i < chunks.length; i++) { | ||||
|                 final Chunk chunk = chunks[i]; | ||||
|                 final int keyLength = (int) (i - keyAddress - 1); | ||||
|  | ||||
|                 threads[i] = Thread.ofVirtual().start(() -> { | ||||
|                     try { | ||||
|                         var bb = channel.map(READ_ONLY, chunk.start, chunk.end - chunk.start); | ||||
|                         process(bb); | ||||
|                 if ((b = UNSAFE.getByte(i++)) == MINUS) { | ||||
|                     while ((b = UNSAFE.getByte(i++)) != DOT) { | ||||
|                         measurement = measurement * 10 + b - ZERO_DIGIT; | ||||
|                     } | ||||
|                     catch (IOException e) { | ||||
|                         throw new RuntimeException(e); | ||||
|                     } | ||||
|                 }); | ||||
|             } | ||||
|  | ||||
|             for (Thread t : threads) { | ||||
|                 t.join(); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         private static Chunk[] split(final FileChannel channel) throws IOException { | ||||
|             final long fileSize = channel.size(); | ||||
|             if (fileSize < 10000) { | ||||
|                 return new Chunk[]{ new Chunk(0, fileSize) }; | ||||
|             } | ||||
|  | ||||
|             final int numChunks = 8; | ||||
|             final long chunkSize = fileSize / numChunks; | ||||
|             final var chunks = new Chunk[numChunks]; | ||||
|  | ||||
|             for (int i = 0; i < numChunks; i++) { | ||||
|                 long start = 0; | ||||
|                 long end = chunkSize; | ||||
|  | ||||
|                 if (i > 0) { | ||||
|                     start = chunks[i - 1].end + 1; | ||||
|                     end = Math.min(start + chunkSize, fileSize); | ||||
|                 } | ||||
|  | ||||
|                 end = end == fileSize ? end : seekNextNewline(channel, end); | ||||
|                 chunks[i] = new Chunk(start, end); | ||||
|             } | ||||
|             return chunks; | ||||
|         } | ||||
|  | ||||
|         private static long seekNextNewline(final FileChannel channel, final long end) throws IOException { | ||||
|             var bb = ByteBuffer.allocate(MAX_KEY_LENGTH); | ||||
|             channel.position(end).read(bb); | ||||
|  | ||||
|             for (int i = 0; i < bb.limit(); i++) { | ||||
|                 if (bb.get(i) == NL) { | ||||
|                     return end + i; | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             throw new IllegalStateException("Couldn't find next newline"); | ||||
|         } | ||||
|  | ||||
|         private void process(final ByteBuffer bb) { | ||||
|             final var sample = new Sample(); | ||||
|             var isKey = true; | ||||
|  | ||||
|             for (long i = 0, sz = bb.limit(); i < sz; i++) { | ||||
|  | ||||
|                 final byte b = bb.get(); | ||||
|  | ||||
|                 if (b == SEMICOLON) { | ||||
|                     isKey = false; | ||||
|                 } | ||||
|                 else if (b == NL) { | ||||
|                     isKey = true; | ||||
|                     addSample(sample); | ||||
|                     sample.reset(); | ||||
|                 } | ||||
|                 else if (isKey) { | ||||
|                     sample.pushKey(b); | ||||
|                 } | ||||
|                 else if (b == DOT) { | ||||
|                     // skip | ||||
|                 } | ||||
|                 else if (b == MINUS) { | ||||
|                     sample.sign = -1; | ||||
|                     b = UNSAFE.getByte(i); | ||||
|                     measurement = measurement * 10 + b - ZERO_DIGIT; | ||||
|                     measurement = -measurement; | ||||
|                     i += 2; | ||||
|                 } | ||||
|                 else { | ||||
|                     sample.pushMeasurement(b); | ||||
|                     measurement = b - ZERO_DIGIT; // D1 | ||||
|                     b = UNSAFE.getByte(i); // dot or D2 | ||||
|  | ||||
|                     if (b == DOT) { | ||||
|                         measurement = measurement * 10 + UNSAFE.getByte(i + 1) - ZERO_DIGIT; // F | ||||
|                         i += 3; | ||||
|                     } | ||||
|                     else { | ||||
|                         measurement = measurement * 10 + b - ZERO_DIGIT; // D2 | ||||
|                         measurement = measurement * 10 + UNSAFE.getByte(i + 2) - ZERO_DIGIT; // F | ||||
|                         i += 4; // skip NL | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 final Stats stats = map.putStats(keyHash, keyAddress, keyLength); | ||||
|                 stats.min = Math.min(stats.min, measurement); | ||||
|                 stats.max = Math.max(stats.max, measurement); | ||||
|                 stats.sum += measurement; | ||||
|                 stats.count++; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         private void addSample(final Sample sample) { | ||||
|             final Stats stats = map.computeIfAbsent(sample.keyHash, | ||||
|                     k -> new Stats(new String(sample.keyBytes, 0, sample.keyLength, UTF_8))); | ||||
|  | ||||
|             final var val = sample.getMeasurement(); | ||||
|  | ||||
|             if (val < stats.min) | ||||
|                 stats.min = val; | ||||
|  | ||||
|             if (val > stats.max) | ||||
|                 stats.max = val; | ||||
|  | ||||
|             stats.sum += val; | ||||
|             stats.count++; | ||||
|         } | ||||
|  | ||||
|         void printStats() { | ||||
|             var sorted = new ArrayList<>(map.values()); | ||||
|             Collections.sort(sorted); | ||||
|  | ||||
|             int size = sorted.size(); | ||||
|  | ||||
|             System.out.print('{'); | ||||
|  | ||||
|             for (Stats stats : sorted) { | ||||
|                 stats.print(System.out); | ||||
|                 if (--size > 0) { | ||||
|                     System.out.print(", "); | ||||
|                 } | ||||
|             } | ||||
|             System.out.println('}'); | ||||
|             return map; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static class Stats implements Comparable<Stats> { | ||||
|         private final String city; | ||||
|         private String key; | ||||
|         private final byte[] keyBytes; | ||||
|         private final int keyLength; | ||||
|         private final int keyHash; | ||||
|         private int min = Integer.MAX_VALUE; | ||||
|         private int max = Integer.MIN_VALUE; | ||||
|         private long sum; | ||||
|         private int count; | ||||
|         private long sum; | ||||
|  | ||||
|         private Stats(String city) { | ||||
|             this.city = city; | ||||
|         private Stats(long keyAddress, int keyLength, int keyHash) { | ||||
|             this.keyLength = keyLength; | ||||
|             this.keyBytes = new byte[keyLength]; | ||||
|             this.keyHash = keyHash; | ||||
|  | ||||
|             for (int i = 0; i < keyLength; i++) { | ||||
|                 keyBytes[i] = UNSAFE.getByte(keyAddress++); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Returns the key as a {@code String}. The first call decodes the first | ||||
|          * {@code keyLength} bytes of {@code keyBytes} as UTF-8 and caches the | ||||
|          * result in {@code key}; later calls return the cached value. | ||||
|          */ | ||||
|         String getKey() { | ||||
|             String decoded = key; | ||||
|             if (decoded == null) { | ||||
|                 decoded = new String(keyBytes, 0, keyLength, UTF_8); | ||||
|                 key = decoded; | ||||
|             } | ||||
|             return decoded; | ||||
|         } | ||||
|  | ||||
|         @Override | ||||
|         public int compareTo(final Stats o) { | ||||
|             return city.compareTo(o.city); | ||||
|             return getKey().compareTo(o.getKey()); | ||||
|         } | ||||
|  | ||||
|         void print(final PrintStream out) { | ||||
|             out.print(city); | ||||
|             out.print(key); | ||||
|             out.print('='); | ||||
|             out.print(round(min / 10f)); | ||||
|             out.print('/'); | ||||
| @@ -210,32 +165,148 @@ public class CalculateAverage_armandino { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static class Sample { | ||||
|         private final byte[] keyBytes = new byte[MAX_KEY_LENGTH]; | ||||
|         private int keyLength; | ||||
|         private int keyHash; | ||||
|         private int measurement; | ||||
|         private int sign = 1; | ||||
|     private static void print(final Collection<Stats> sorted) { | ||||
|         int size = sorted.size(); | ||||
|         System.out.print('{'); | ||||
|         for (Stats stats : sorted) { | ||||
|             stats.print(System.out); | ||||
|             if (--size > 0) { | ||||
|                 System.out.print(", "); | ||||
|             } | ||||
|         } | ||||
|         System.out.println('}'); | ||||
|     } | ||||
|  | ||||
|         void pushKey(byte b) { | ||||
|             keyBytes[keyLength++] = b; | ||||
|             keyHash = 31 * keyHash + b; | ||||
|     private static Chunk[] split(final FileChannel channel) throws IOException { | ||||
|         final long fileSize = channel.size(); | ||||
|         long start = channel.map(READ_ONLY, 0, fileSize, Arena.global()).address(); | ||||
|         final long endAddress = start + fileSize; | ||||
|         if (fileSize < 10000) { | ||||
|             return new Chunk[]{ new Chunk(start, endAddress) }; | ||||
|         } | ||||
|  | ||||
|         void pushMeasurement(byte b) { | ||||
|             final int i = b - '0'; | ||||
|             measurement = measurement * 10 + i; | ||||
|         final long chunkSize = fileSize / NUM_CHUNKS; | ||||
|         final var chunks = new Chunk[NUM_CHUNKS]; | ||||
|         long end = start + chunkSize; | ||||
|  | ||||
|         for (int i = 0; i < NUM_CHUNKS; i++) { | ||||
|             if (i > 0) { | ||||
|                 start = chunks[i - 1].end; | ||||
|                 end = Math.min(start + chunkSize, endAddress); | ||||
|             } | ||||
|             if (end < endAddress) { | ||||
|                 while (UNSAFE.getByte(end) != NL) { | ||||
|                     end++; | ||||
|                 } | ||||
|                 end++; | ||||
|             } | ||||
|             chunks[i] = new Chunk(start, end); | ||||
|         } | ||||
|         return chunks; | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * A slice of the memory-mapped input, as produced by {@code split}. | ||||
|      * | ||||
|      * @param start native address of the first byte of the slice | ||||
|      * @param end   native address one past the last byte of the slice | ||||
|      *              (the chunk processor loops while {@code i < end}) | ||||
|      */ | ||||
|     private record Chunk(long start, long end) { | ||||
|     } | ||||
|  | ||||
|     private static Unsafe getUnsafe() { | ||||
|         try { | ||||
|             Field unsafe = Unsafe.class.getDeclaredField("theUnsafe"); | ||||
|             unsafe.setAccessible(true); | ||||
|             return (Unsafe) unsafe.get(null); | ||||
|         } | ||||
|         catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static class SimpleMap { | ||||
|         private Stats[] table; | ||||
|  | ||||
|         /** | ||||
|          * Creates an empty map backed by an array of {@code initialCapacity} slots. | ||||
|          * Slot indices are computed as {@code (table.length - 1) & keyHash}, so the | ||||
|          * capacity is expected to be a power of two; this is not checked here. | ||||
|          * | ||||
|          * @param initialCapacity number of slots in the initial table | ||||
|          */ | ||||
|         SimpleMap(int initialCapacity) { | ||||
|             this.table = new Stats[initialCapacity]; | ||||
|         } | ||||
|  | ||||
|         int getMeasurement() { | ||||
|             return sign * measurement; | ||||
|         /** Returns a stream over the occupied slots of the table, in table order. */ | ||||
|         Stream<Stats> stream() { | ||||
|             return Stream.of(table).filter(entry -> entry != null); | ||||
|         } | ||||
|  | ||||
|         void reset() { | ||||
|             keyHash = 0; | ||||
|             keyLength = 0; | ||||
|             measurement = 0; | ||||
|             sign = 1; | ||||
|         /** | ||||
|          * Doubles the table and re-inserts every existing entry. | ||||
|          * Placement follows the same probe order as {@code putStats}: the home slot | ||||
|          * {@code (length - 1) & keyHash}, then upwards to the end of the array, then | ||||
|          * downwards from the home slot to index 0. | ||||
|          * | ||||
|          * @throws IllegalStateException if no free slot is found; not expected, since | ||||
|          *         the new table has twice as many slots as there are entries | ||||
|          */ | ||||
|         private void resize() { | ||||
|             final Stats[] grown = new Stats[table.length * 2]; | ||||
|             final int mask = grown.length - 1; | ||||
|             for (final Stats entry : table) { | ||||
|                 if (entry == null) { | ||||
|                     continue; | ||||
|                 } | ||||
|                 final int home = mask & entry.keyHash; | ||||
|                 // forward probe, starting at the home slot itself | ||||
|                 int slot = home; | ||||
|                 while (slot < grown.length && grown[slot] != null) { | ||||
|                     slot++; | ||||
|                 } | ||||
|                 if (slot == grown.length) { | ||||
|                     // ran off the end: probe backwards from the home slot | ||||
|                     slot = home; | ||||
|                     while (slot >= 0 && grown[slot] != null) { | ||||
|                         slot--; | ||||
|                     } | ||||
|                 } | ||||
|                 if (slot < 0) { | ||||
|                     throw new IllegalStateException("table is full"); | ||||
|                 } | ||||
|                 grown[slot] = entry; | ||||
|             } | ||||
|             table = grown; | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Returns the {@code Stats} entry for the key, creating it if absent. | ||||
|          * The key is the {@code keyLength} bytes starting at native address {@code keyAddress}. | ||||
|          * Probe order: the home slot {@code (table.length - 1) & keyHash}, then upwards | ||||
|          * to the end of the table, then downwards from just below the home slot to index 0. | ||||
|          * The first empty slot receives a new entry. If every slot is occupied by other | ||||
|          * keys, the table is resized and the lookup is retried. | ||||
|          * | ||||
|          * @param keyHash    hash of the key bytes | ||||
|          * @param keyAddress native address of the first key byte | ||||
|          * @param keyLength  number of key bytes | ||||
|          * @return the existing or newly created entry for the key | ||||
|          */ | ||||
|         Stats putStats(final int keyHash, final long keyAddress, final int keyLength) { | ||||
|             final int home = (table.length - 1) & keyHash; | ||||
|             // forward: home, home + 1, ..., table.length - 1 | ||||
|             for (int slot = home; slot < table.length; slot++) { | ||||
|                 final Stats candidate = table[slot]; | ||||
|                 if (candidate == null) { | ||||
|                     return createAt(table, keyAddress, keyLength, keyHash, slot); | ||||
|                 } | ||||
|                 if (candidate.keyHash == keyHash && keysEqual(candidate, keyAddress, keyLength)) { | ||||
|                     return candidate; | ||||
|                 } | ||||
|             } | ||||
|             // backward: home - 1, ..., 0 | ||||
|             for (int slot = home - 1; slot >= 0; slot--) { | ||||
|                 final Stats candidate = table[slot]; | ||||
|                 if (candidate == null) { | ||||
|                     return createAt(table, keyAddress, keyLength, keyHash, slot); | ||||
|                 } | ||||
|                 if (candidate.keyHash == keyHash && keysEqual(candidate, keyAddress, keyLength)) { | ||||
|                     return candidate; | ||||
|                 } | ||||
|             } | ||||
|             // every slot holds a different key: grow, then retry against the larger table | ||||
|             resize(); | ||||
|             return putStats(keyHash, keyAddress, keyLength); | ||||
|         } | ||||
|  | ||||
|         private boolean keysEqual(Stats stats, long keyAddress, final int keyLength) { | ||||
|             if (stats.keyLength != keyLength) { | ||||
|                 return false; | ||||
|             } | ||||
|             for (int i = 0; i < keyLength; i++) { | ||||
|                 if (stats.keyBytes[i] != UNSAFE.getByte(keyAddress++)) { | ||||
|                     return false; | ||||
|                 } | ||||
|             } | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Creates a new {@code Stats} for the given key, stores it at index {@code i} | ||||
|          * of {@code table}, and returns it. | ||||
|          * | ||||
|          * @param key the key's hash, passed through to the {@code Stats} constructor | ||||
|          */ | ||||
|         private static Stats createAt(Stats[] table, long keyAddress, int keyLength, int key, int i) { | ||||
|             return table[i] = new Stats(keyAddress, keyLength, key); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user