/* * Copyright 2023 The original authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package dev.morling.onebrc; import java.io.IOException; import java.lang.foreign.Arena; import java.lang.reflect.Field; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; import sun.misc.Unsafe; /** * Changelog: * * Initial submission: 62000 ms * Chunked reader: 16000 ms * Optimized parser: 13000 ms * Branchless methods: 11000 ms * Adding memory mapped files: 6500 ms (based on bjhara's submission) * Skipping string creation: 4700 ms * Custom hashmap... 4200 ms * Added SWAR token checks: 3900 ms * Skipped String creation: 3500 ms (idea from kgonia) * Improved String skip: 3250 ms * Segmenting files: 3150 ms (based on spullara's code) * Not using SWAR for EOL: 2850 ms * Inlining hash calculation: 2450 ms * Replacing branchless code: 2200 ms (sometimes we need to kill the things we love) * Added unsafe memory access: 1900 ms (keeping the long[] small and local) * Fixed bug, UNSAFE bytes String: 1850 ms * Separate hash from entries: 1550 ms * Various tweaks for Linux/cache 1550 ms (should/could make a difference on target machine) * Improved layout/predictability 1450 ms (on par with Thomas Wuerthinger) * * Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai for ideas. */ public class CalculateAverage_royvanrijn { private static final String FILE = "./measurements.txt"; private static final Unsafe UNSAFE = initUnsafe(); private static Unsafe initUnsafe() { try { final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); theUnsafe.setAccessible(true); return (Unsafe) theUnsafe.get(Unsafe.class); } catch (NoSuchFieldException | IllegalAccessException e) { throw new RuntimeException(e); } } public static void main(String[] args) throws Exception { // Calculate input segments. final int numberOfChunks = Runtime.getRuntime().availableProcessors(); final long[] chunks = getSegments(numberOfChunks); final List repositories = IntStream.range(0, chunks.length - 1) .mapToObj(chunkIndex -> processMemoryArea(chunks[chunkIndex], chunks[chunkIndex + 1])) .parallel() .toList(); // Sometimes simple is better: final HashMap measurements = HashMap.newHashMap(1 << 10); for (Entry[] entries : repositories) { for (Entry entry : entries) { if (entry != null) measurements.merge(entry.city, entry, Entry::mergeWith); } } System.out.print("{" + measurements.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Object::toString).collect(Collectors.joining(", "))); System.out.println("}"); } /** * Simpler way to get the segments and launch parallel processing by thomaswue */ private static long[] getSegments(final int numberOfChunks) throws IOException { try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { final long fileSize = fileChannel.size(); final long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks; final long[] chunks = new long[numberOfChunks + 1]; final long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); chunks[0] = mappedAddress; final long endAddress = mappedAddress + fileSize; for (int i = 1; i < numberOfChunks; ++i) { long chunkAddress = mappedAddress + i * segmentSize; // Align to first row start. while (chunkAddress < endAddress && UNSAFE.getByte(chunkAddress++) != '\n') { // nop } chunks[i] = Math.min(chunkAddress, endAddress); } chunks[numberOfChunks] = endAddress; return chunks; } } private static final int TABLE_SIZE = 1 << 18; // large enough for the contest. private static final int TABLE_MASK = (TABLE_SIZE - 1); static final class Entry { private final long[] data; private final String city; private int min, max, count; private long sum; Entry(final long[] data, String city, int temp) { this.data = data; this.city = city; this.min = temp; this.max = temp; this.sum = temp; this.count = 1; } public void updateWith(int measurement) { min = Math.min(min, measurement); max = Math.max(max, measurement); sum += measurement; count++; } public Entry mergeWith(Entry entry) { min = Math.min(min, entry.min); max = Math.max(max, entry.max); sum += entry.sum; count += entry.count; return this; } public String toString() { return round(min) + "/" + round((1.0 * sum) / count) + "/" + round(max); } private static double round(double value) { return Math.round(value) / 10.0; } } private static Entry createNewEntry(final long[] buffer, final long startAddress, final int lengthLongs, final int lengthBytes, final int temp) { // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here. final byte[] bytes = new byte[lengthBytes]; UNSAFE.copyMemory(null, startAddress, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, lengthBytes); final String city = new String(bytes, StandardCharsets.UTF_8); final long[] bufferCopy = new long[lengthLongs]; System.arraycopy(buffer, 0, bufferCopy, 0, lengthLongs); // Add the entry: return new Entry(bufferCopy, city, temp); } private static Entry[] processMemoryArea(final long fromAddress, final long toAddress) { Entry[] table = new Entry[TABLE_SIZE]; long ptr = fromAddress; long[] buffer = new long[14]; while (ptr < toAddress) { int bufferPtr = 0; long startAddress = ptr; long hash = 1; long word = UNSAFE.getLong(ptr); long mask = getDelimiterMask(word); while (mask == 0) { buffer[bufferPtr++] = word; hash ^= word; ptr += 8; word = UNSAFE.getLong(ptr); mask = getDelimiterMask(word); } // Found delimiter: final long delimiterAddress = ptr + (Long.numberOfTrailingZeros(mask) >> 3); final long numberBits = UNSAFE.getLong(delimiterAddress + 1); // Finish the masks and hash: final long partialWord = word & ((mask >> 7) - 1); buffer[bufferPtr++] = partialWord; hash ^= partialWord; final long invNumberBits = ~numberBits; final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & DOT_BITS); // Update counter asap, lets CPU predict. ptr = delimiterAddress + (decimalSepPos >> 3) + 4; int intHash = (int) (hash ^ (hash >>> 31)); // offset for extra entropy // Awesome idea of merykitty: final int temp = extractTemp(numberBits, invNumberBits, decimalSepPos); int index = intHash & TABLE_MASK; // Find or insert the entry: while (true) { Entry tableEntry = table[index]; if (tableEntry == null) { final int length = (int) (delimiterAddress - startAddress); table[index] = createNewEntry(buffer, startAddress, bufferPtr, length, temp); break; } else if (bufferPtr == tableEntry.data.length) { if (!arrayEquals(buffer, tableEntry.data, bufferPtr)) { index = (index + 1) & TABLE_MASK; continue; } // No differences in array tableEntry.updateWith(temp); break; } // Move to the next index index = (index + 1) & TABLE_MASK; } } return table; } private static int extractTemp(final long numberBits, final long invNumberBits, final int decimalSepPos) { final long signed = (invNumberBits << 59) >> 63; final long minusFilter = ~(signed & 0xFF); final long digits = ((numberBits & minusFilter) << (28 - decimalSepPos)) & 0x0F000F0F00L; final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; // filter just the result final int temp = (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick return temp; } private static long getDelimiterMask(final long word) { long match = word ^ SEPARATOR_PATTERN; return (match - 0x0101010101010101L) & ~match & 0x8080808080808080L; } private static final long SEPARATOR_PATTERN = compilePattern((byte) ';'); private static final long DOT_BITS = 0x10101000; private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); private static long compilePattern(final byte value) { return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) | ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value; } /** * For case multiple hashes are equal (however unlikely) check the actual key (using longs) */ static boolean arrayEquals(final long[] a, final long[] b, final int length) { for (int i = 0; i < length; i++) { if (a[i] != b[i]) return false; } return true; } }