/*
 *  Copyright 2023 The original authors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package dev.morling.onebrc;

import java.lang.foreign.Arena;
import java.lang.reflect.Field;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

import sun.misc.Unsafe;

/**
 * Changelog:
 *
 * Initial submission:               62000 ms
 * Chunked reader:                   16000 ms
 * Optimized parser:                 13000 ms
 * Branchless methods:               11000 ms
 * Adding memory mapped files:       6500 ms (based on bjhara's submission)
 * Skipping string creation:         4700 ms
 * Custom hashmap...                 4200 ms
 * Added SWAR token checks:          3900 ms
 * Skipped String creation:          3500 ms (idea from kgonia)
 * Improved String skip:             3250 ms
 * Segmenting files:                 3150 ms (based on spullara's code)
 * Not using SWAR for EOL:           2850 ms
 * Inlining hash calculation:        2450 ms
 * Replacing branchless code:        2200 ms (sometimes we need to kill the things we love)
 * Added unsafe memory access:       1900 ms (keeping the long[] small and local)
 * Fixed bug, UNSAFE bytes String:   1850 ms
 * Separate hash from entries:       1550 ms
 * Various tweaks for Linux/cache    1550 ms (should/could make a difference on target machine)
 * Improved layout/predictability:   1400 ms
 * Delayed String creation again:    1350 ms
 * Remove writing to buffer:         1335 ms
 * Optimized collecting at the end:  1310 ms
 * Adding a lot of comments:         priceless
 * Changed to flyweight byte[]:      1290 ms (adds even more Unsafe, was initially slower, now faster)
 * More LOC now parallel:            1260 ms (moved more to processMemoryArea, recombining in ConcurrentHashMap)
 * Storing only the address:         1240 ms (this is now faster, tried before, was slower)
 * Unrolling scan-loop:              1200 ms (seems to help, perhaps even more on target machine)
 * Adding more readable reader:      1300 ms (scores got worse on target machine anyway)
 *
 * I've ditched my M2 for an older x86-64 MacBook, this allows me to run `perf` and I'm trying to get lower numbers by trial and error.
 *
 * Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai and many others for ideas.
 *
 * Follow me at: @royvanrijn
 */
public class CalculateAverage_royvanrijn {

    // Input file, resolved relative to the working directory.
    private static final String FILE = "./measurements.txt";
    // private static final String FILE = "src/test/resources/samples/measurements-1.txt";

    // Raw memory access, used both for the mapped file and for the flyweight byte[] entries.
    private static final Unsafe UNSAFE = initUnsafe();

    // Number of segments the input is split into: one per available processor
    // (PROCESSORS - 1 worker threads plus the main thread).
    private static final int PROCESSORS = Runtime.getRuntime().availableProcessors();

    /**
     * Flyweight entry in a byte[]; the name length is stored in a single byte.
     *
     * Layout, in the order the offsets below are derived:
     *
     * byte: length
     * long: sum
     * int:  min
     * int:  max
     * int:  count
     * byte[]: cityname
     *
     * Every ENTRY_* offset already includes Unsafe.ARRAY_BYTE_BASE_OFFSET, so it can be passed
     * directly to UNSAFE.getXxx(entry, offset) / UNSAFE.putXxx(entry, offset, value).
     */
    // ------------------------------------------------------------------------
    private static final int ENTRY_LENGTH = (Unsafe.ARRAY_BYTE_BASE_OFFSET);
    private static final int ENTRY_SUM = (ENTRY_LENGTH + Byte.BYTES);
    private static final int ENTRY_MIN = (ENTRY_SUM + Long.BYTES);
    private static final int ENTRY_MAX = (ENTRY_MIN + Integer.BYTES);
    private static final int ENTRY_COUNT = (ENTRY_MAX + Integer.BYTES);
    private static final int ENTRY_NAME = (ENTRY_COUNT + Integer.BYTES);
    // Offsets of the second and third 8-byte word of the name.
    private static final int ENTRY_NAME_8 = ENTRY_NAME + 8;
    private static final int ENTRY_NAME_16 = ENTRY_NAME + 16;

    // Entry arrays are allocated as ENTRY_BASESIZE_WHITESPACE + name length.
    private static final int ENTRY_BASESIZE_WHITESPACE = ENTRY_NAME + 7; // with enough empty bytes to fill a long
    // ------------------------------------------------------------------------
    // Name capacity of each pre-allocated entry.
    private static final int PREMADE_MAX_SIZE = 1 << 5; // pre-initialize some entries in memory, keep them close
    private static final int PREMADE_ENTRIES = 512; // amount of pre-created entries we should use
    // Number of slots in each per-segment lookup table; must be a power of two because
    // indices are wrapped with TABLE_MASK.
    private static final int TABLE_SIZE = 1 << 19; // large enough for the contest.
    private static final int TABLE_MASK = (TABLE_SIZE - 1);

    /**
     * Entry point: memory-maps the measurements file, splits it into {@code PROCESSORS} segments,
     * aggregates each segment on its own thread (the last one on the calling thread), merges the
     * per-segment tables by station name and prints the result as
     * <code>{name=min/mean/max, ...}</code> sorted by name.
     *
     * @param args ignored
     * @throws Exception if the file cannot be opened or mapped, or if the calling thread is
     *                   interrupted while waiting for the workers
     */
    public static void main(String[] args) throws Exception {

        // Map the whole file. The mapping lives in the global arena and does not depend on the
        // channel, so the channel is closed as soon as the mapping exists.
        final long fileSize;
        final long mapAddress;
        try (FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
            fileSize = fileChannel.size();
            mapAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
        }
        final long endOfFile = mapAddress + fileSize;
        final long segmentSize = (fileSize + PROCESSORS - 1) / PROCESSORS;

        // This is where the entries of all segments land, keyed by station name:
        final ConcurrentHashMap<String, byte[]> measurements = new ConcurrentHashMap<>(1 << 10);

        // One worker thread per segment, except for the last segment which runs on this thread.
        final Thread[] parallelThreads = new Thread[PROCESSORS - 1];
        long lastAddress = mapAddress;
        for (int i = 0; i < parallelThreads.length; ++i) {

            final long fromAddress = lastAddress;
            final long toAddress = Math.min(endOfFile, fromAddress + segmentSize);

            final Thread thread = new Thread(
                    () -> mergeTable(processMemoryArea(fromAddress, toAddress, fromAddress == mapAddress), measurements));
            thread.start(); // start a.s.a.p.
            parallelThreads[i] = thread;
            lastAddress = toAddress;
        }

        // Use the current thread for the remaining part of memory. When there are no worker
        // threads this segment starts at the beginning of the file and must not be re-aligned
        // (that would read in front of the mapping and skip the first line).
        mergeTable(processMemoryArea(lastAddress, endOfFile, lastAddress == mapAddress), measurements);

        // Wait for all threads to finish:
        for (Thread thread : parallelThreads) {
            thread.join();
        }

        System.out.print("{" +
                measurements.entrySet().stream().sorted(Map.Entry.comparingByKey())
                        .map(entry -> entry.getKey() + '=' + entryValuesToString(entry.getValue()))
                        .collect(Collectors.joining(", ")));
        System.out.println("}");
    }

    /** Folds every occupied slot of a per-segment table into the shared result map. */
    private static void mergeTable(final byte[][] table, final ConcurrentHashMap<String, byte[]> measurements) {
        for (byte[] entry : table) {
            if (entry != null) {
                measurements.merge(entryToName(entry), entry, CalculateAverage_royvanrijn::mergeEntry);
            }
        }
    }

    /**
     * Initializes a flyweight entry for a station that is seen for the first time.
     *
     * @param entry       backing array to initialize; must have room for the header plus {@code length} name bytes
     * @param fromAddress absolute memory address of the first byte of the station name
     * @param length      number of name bytes to copy (stored truncated to one byte)
     * @param temp        first measurement, used as sum, min and max
     * @return the same {@code entry} array
     */
    private static byte[] fillEntry(final byte[] entry, final long fromAddress, final int length, final int temp) {
        // Header: name length, then the single measurement seen so far.
        final long initialSum = temp;
        UNSAFE.putByte(entry, ENTRY_LENGTH, (byte) length);
        UNSAFE.putInt(entry, ENTRY_COUNT, 1);
        UNSAFE.putInt(entry, ENTRY_MAX, temp);
        UNSAFE.putInt(entry, ENTRY_MIN, temp);
        UNSAFE.putLong(entry, ENTRY_SUM, initialSum);

        // Payload: raw name bytes, copied from the absolute address into the array.
        UNSAFE.copyMemory(null, fromAddress, entry, ENTRY_NAME, length);
        return entry;
    }

    /**
     * Adds one measurement to an existing entry: widens min/max, increments the count and
     * adds the value to the running sum.
     *
     * @param entry flyweight entry to update in place
     * @param temp  measurement to add
     */
    public static void updateEntry(final byte[] entry, final int temp) {

        // Read everything first ...
        final long total = UNSAFE.getLong(entry, ENTRY_SUM) + temp;
        final int samples = UNSAFE.getInt(entry, ENTRY_COUNT) + 1;
        final int lowest = Math.min(temp, UNSAFE.getInt(entry, ENTRY_MIN));
        final int highest = Math.max(temp, UNSAFE.getInt(entry, ENTRY_MAX));

        // ... then write the four fields back.
        UNSAFE.putLong(entry, ENTRY_SUM, total);
        UNSAFE.putInt(entry, ENTRY_COUNT, samples);
        UNSAFE.putInt(entry, ENTRY_MIN, lowest);
        UNSAFE.putInt(entry, ENTRY_MAX, highest);
    }

    /**
     * Combines the statistics of two entries for the same station.
     *
     * @param entry entry that receives the combined sum, min, max and count (modified in place)
     * @param merge entry whose statistics are added; it is only read
     * @return {@code entry}
     */
    public static byte[] mergeEntry(final byte[] entry, final byte[] merge) {

        // Gather the combined values from both sides before touching the target.
        final long combinedSum = UNSAFE.getLong(merge, ENTRY_SUM) + UNSAFE.getLong(entry, ENTRY_SUM);
        final int combinedCount = UNSAFE.getInt(merge, ENTRY_COUNT) + UNSAFE.getInt(entry, ENTRY_COUNT);
        final int combinedMin = Math.min(UNSAFE.getInt(entry, ENTRY_MIN), UNSAFE.getInt(merge, ENTRY_MIN));
        final int combinedMax = Math.max(UNSAFE.getInt(entry, ENTRY_MAX), UNSAFE.getInt(merge, ENTRY_MAX));

        UNSAFE.putInt(entry, ENTRY_COUNT, combinedCount);
        UNSAFE.putInt(entry, ENTRY_MAX, combinedMax);
        UNSAFE.putInt(entry, ENTRY_MIN, combinedMin);
        UNSAFE.putLong(entry, ENTRY_SUM, combinedSum);
        return entry;
    }

    /**
     * Extracts the station name stored in a flyweight entry.
     *
     * @param entry flyweight entry written by {@code fillEntry}
     * @return the name bytes decoded as UTF-8
     */
    private static String entryToName(final byte[] entry) {
        // The length was stored as a single byte; read it back unsigned so that names of
        // 128..255 bytes do not turn into a negative array size.
        final int length = Byte.toUnsignedInt(UNSAFE.getByte(entry, ENTRY_LENGTH));

        final byte[] name = new byte[length];
        UNSAFE.copyMemory(entry, ENTRY_NAME, name, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);

        // Decode the copied name bytes:
        return new String(name, StandardCharsets.UTF_8);
    }

    /**
     * Formats the statistics of an entry as {@code min/mean/max}.
     * The stored values are in tenths; each one is passed through {@code round}.
     *
     * @param entry flyweight entry to format
     * @return the three values separated by '/'
     */
    private static String entryValuesToString(final byte[] entry) {
        final double min = round(UNSAFE.getInt(entry, ENTRY_MIN));
        final double mean = round((1.0 * UNSAFE.getLong(entry, ENTRY_SUM)) / UNSAFE.getInt(entry, ENTRY_COUNT));
        final double max = round(UNSAFE.getInt(entry, ENTRY_MAX));
        return min + "/" + mean + "/" + max;
    }

    /**
     * Debug helper: renders {@code length} bytes as a String, one char per byte.
     *
     * @param target  object the bytes are read from, or {@code null} to treat {@code address} as absolute
     * @param address offset within {@code target} (or absolute address) of the first byte
     * @param length  number of bytes to render; values below one yield an empty String
     * @return the bytes, each cast to a {@code char}
     */
    private static String printMemory(final Object target, final long address, int length) {
        // StringBuilder instead of String += in a loop (avoids quadratic copying).
        final StringBuilder result = new StringBuilder(Math.max(length, 0));
        for (int i = 0; i < length; i++) {
            result.append((char) UNSAFE.getByte(target, address + i));
        }
        return result.toString();
    }

    /**
     * Debug helper: renders the lowest {@code length} bytes of a long as a String,
     * least significant byte first.
     *
     * @param value  the 8 bytes to render
     * @param length number of bytes to render; values below one yield an empty String
     * @return one char per byte
     */
    private static String printMemory(final long value, int length) {
        // StringBuilder instead of String += in a loop (avoids quadratic copying).
        final StringBuilder result = new StringBuilder(Math.max(length, 0));
        for (int i = 0; i < length; i++) {
            result.append((char) ((value >> (i << 3)) & 0xFF));
        }
        return result.toString();
    }

    /**
     * Converts a value expressed in tenths to its decimal form with one fractional digit.
     *
     * @param value value in tenths
     * @return {@code value} rounded to the nearest whole tenth (via {@link Math#round(double)}), divided by ten
     */
    private static double round(final double value) {
        final long nearestTenth = Math.round(value);
        return nearestTenth / 10.0;
    }

    /**
     * Cursor over one segment of the mapped file.
     *
     * A line is consumed in three steps: {@link #processStart()} marks the beginning of the name,
     * {@link #readFirst()} / {@link #readNext()} load 8 bytes at a time and look for the ';'
     * delimiter (followed by {@link #processName()} for every 8-byte word that contains none), and
     * {@link #processEndAndGetTemperature()} finishes the hash, parses the temperature and moves
     * the cursor to the next line. The {@code matchesEntry*} methods then compare the name that
     * was just read against a stored entry.
     */
    private static final class Reader {

        // Current read position (absolute address).
        private long ptr;
        // Result of the last delimiter scan: 0x80 in every byte of lastRead that holds ';'.
        private long delimiterMask;
        // The most recently loaded 8 bytes, and the 8 bytes loaded before those.
        private long lastRead;
        private long lastReadMinOne;

        // Running hash of the current name, plus the addresses of its first byte and of its ';'.
        private long hash;
        private long entryStart;
        private long entryDelimiter;

        private final long endAddress;

        /**
         * @param startAddress first address of the segment
         * @param endAddress   first address after the segment; lines that start before it are
         *                     read to their end even when they cross it
         * @param isFileStart  {@code true} when {@code startAddress} is the beginning of the file,
         *                     in which case no alignment is done
         */
        Reader(final long startAddress, final long endAddress, final boolean isFileStart) {

            this.ptr = startAddress;
            this.endAddress = endAddress;

            // Align to the start of a line: step one byte back and move just past the first
            // newline found. A segment that already starts right after a newline is left as is.
            if (!isFileStart) {
                ptr--;
                while (ptr < endAddress) {
                    if (UNSAFE.getByte(ptr++) == '\n') {
                        break;
                    }
                }
            }
        }

        /** Starts a new line: resets the hash and remembers where the name begins. */
        private void processStart() {
            hash = 0;
            entryStart = ptr;
        }

        /** @return {@code true} while the cursor is still inside the segment */
        private boolean hasNext() {
            return (ptr < endAddress);
        }

        // ';' repeated in every byte.
        private static final long DELIMITER_MASK = 0x3B3B3B3B3B3B3B3BL;

        /**
         * Loads the 8 bytes at the cursor and scans them for ';'.
         *
         * @return {@code true} when NO delimiter was found, i.e. the name continues
         */
        private boolean readFirst() {
            lastRead = UNSAFE.getLong(ptr);

            // Bytes equal to ';' become zero; the expression below flags zero bytes with 0x80.
            final long match = lastRead ^ DELIMITER_MASK;
            delimiterMask = (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);

            return delimiterMask == 0;
        }

        /**
         * Same as {@link #readFirst()}, but keeps the previously loaded word in
         * {@code lastReadMinOne} for the later name comparison.
         *
         * @return {@code true} when NO delimiter was found
         */
        private boolean readNext() {
            lastReadMinOne = lastRead;
            return readFirst();
        }

        /** Folds a delimiter-free word into the hash and advances the cursor by 8 bytes. */
        private void processName() {
            hash ^= lastRead;
            ptr += 8;
        }

        /**
         * Finishes the current line after a delimiter was found.
         *
         * @return the temperature in tenths of a degree
         */
        private int processEndAndGetTemperature() {
            processFinalBytes();

            finalizeHash();
            finalizeDelimiter();

            return readTemperature();
        }

        /** Clears the delimiter and everything after it from {@code lastRead}. */
        private void processFinalBytes() {
            // Keep only the bytes below the FIRST delimiter. m ^ (m - 1) selects all bits up to
            // and including the lowest set bit, so a second ';' later in the same word (possible
            // with one-byte names) cannot leak bits into the name.
            lastRead &= ((delimiterMask ^ (delimiterMask - 1)) >>> 8);
        }

        /** Mixes the last (masked) word into the hash and folds the upper bits down. */
        private void finalizeHash() {
            // Finalize hash:
            hash ^= lastRead;
            hash ^= hash >> 32;
            hash ^= hash >> 17; // extra entropy
        }

        /** Computes the address of the ';' from the position of the lowest flagged byte. */
        private void finalizeDelimiter() {
            // Found delimiter:
            entryDelimiter = ptr + (Long.numberOfTrailingZeros(delimiterMask) >> 3);
        }

        // Bit 4 of bytes 1, 2 and 3: set for ASCII digits, clear for '.'.
        private static final long DOT_BITS = 0x10101000;
        private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);

        /**
         * Parses the temperature that follows the delimiter and moves the cursor to the start of
         * the next line.
         *
         * @return the temperature in tenths of a degree
         */
        // Awesome idea of merykitty:
        private int readTemperature() {
            // This is the number part: X.X, -X.X, XX.x or -XX.X
            long numberBytes = UNSAFE.getLong(entryDelimiter + 1);
            long invNumberBytes = ~numberBytes;

            // Bit position of the '.' (the first of bytes 1..3 whose bit 4 is clear).
            int dotPosition = Long.numberOfTrailingZeros(invNumberBytes & DOT_BITS);

            // Update the pointer here, bit awkward, but we have all the data
            ptr = entryDelimiter + (dotPosition >> 3) + 4;

            int min28 = (28 - dotPosition);
            // Calculates the sign: all ones when the first byte has bit 4 clear ('-'), else zero.
            final long signed = (invNumberBytes << 59) >> 63;
            final long minusFilter = ~(signed & 0xFF);
            // Use the pre-calculated decimal position to adjust the values
            long digits = ((numberBytes & minusFilter) << min28) & 0x0F000F0F00L;
            // Multiply by a magic (100 * 0x1000000 + 10 * 0x10000 + 1), to get the result
            final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
            // Apply the sign: negate when signed is all ones.
            return (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick
        }

        /**
         * Compares a name of any length against a stored entry, 8 bytes at a time.
         * Intended for names whose delimiter lies in the fourth word or later.
         *
         * @param entry stored entry to compare with
         * @return {@code true} when every word of the name equals the entry's name
         */
        private boolean matchesEntryFull(final byte[] entry) {
            // Number of complete 8-byte words in front of the word that holds the delimiter.
            final int longs = (int) (entryDelimiter - entryStart) >> 3;

            // Words 0 .. longs-2 are no longer cached and are re-read from the mapped file.
            // (Stopping one word earlier would leave word longs-2 unchecked.)
            int step = 0;
            for (int i = 0; i < longs - 1; i++) {
                if (UNSAFE.getLong(entryStart + step) != UNSAFE.getLong(entry, ENTRY_NAME + step)) {
                    return false;
                }
                step += 8;
            }
            // Word longs-1 is the last complete word, still available in lastReadMinOne.
            if (lastReadMinOne != UNSAFE.getLong(entry, ENTRY_NAME + step)) {
                return false;
            }
            // Word longs is the masked tail in lastRead.
            return lastRead == UNSAFE.getLong(entry, ENTRY_NAME_8 + step);
        }

        /**
         * Compares a name whose delimiter lies in the third word: first word from memory,
         * second from {@code lastReadMinOne}, masked tail from {@code lastRead}.
         */
        private boolean matchesEntryMedium(final byte[] entry) {
            if (UNSAFE.getLong(entryStart) != UNSAFE.getLong(entry, ENTRY_NAME)) {
                return false;
            }
            if (lastReadMinOne != UNSAFE.getLong(entry, ENTRY_NAME_8)) {
                return false;
            }
            if (lastRead != UNSAFE.getLong(entry, ENTRY_NAME_16)) {
                return false;
            }
            return true;
        }

        /**
         * Compares a name whose delimiter lies in the second word: first word from
         * {@code lastReadMinOne}, masked tail from {@code lastRead}.
         */
        private boolean matchesEntryShort(final byte[] entry) {
            if (lastReadMinOne != UNSAFE.getLong(entry, ENTRY_NAME)) {
                return false;
            }
            if (lastRead != UNSAFE.getLong(entry, ENTRY_NAME_8)) {
                return false;
            }
            return true;
        }

        /** Compares a name whose delimiter lies in the first word: only the masked {@code lastRead}. */
        private boolean matchesEnding(final byte[] entry) {
            return lastRead == UNSAFE.getLong(entry, ENTRY_NAME);
        }

        /** @return the length of the current name in bytes (delimiter excluded) */
        private int length() {
            return (int) (entryDelimiter - entryStart);

        }

    }

    /**
     * Aggregates all lines that start inside one segment of the mapped file.
     *
     * Each line is looked up in a local open-addressing table (linear probing, index taken from
     * the name hash) and either creates a new flyweight entry or updates the existing one. The
     * loop body is written out four times, once per name-length class (delimiter in the first,
     * second, third, or a later 8-byte word), so that each class uses its own comparison method.
     *
     * @param startAddress first address of the segment
     * @param endAddress   first address after the segment
     * @param isFileStart  {@code true} when the segment begins at the start of the file
     * @return the table; unused slots are {@code null}
     */
    private static byte[][] processMemoryArea(final long startAddress, final long endAddress, boolean isFileStart) {

        final byte[][] table = new byte[TABLE_SIZE][];
        // Entries allocated up front in one go; handed out in order until they run out.
        final byte[][] preConstructedEntries = new byte[PREMADE_ENTRIES][ENTRY_BASESIZE_WHITESPACE + PREMADE_MAX_SIZE];

        final Reader reader = new Reader(startAddress, endAddress, isFileStart);

        byte[] entry;
        // Number of pre-constructed entries handed out so far.
        int entryCount = 0;

        // Main loop: one iteration per input line.
        while (reader.hasNext()) {

            reader.processStart();

            // Delimiter within the first 8 bytes of the name.
            if (!reader.readFirst()) {
                int temperature = reader.processEndAndGetTemperature();

                // Find or insert the entry:
                int index = (int) (reader.hash & TABLE_MASK);
                while (true) {
                    entry = table[index];
                    if (entry == null) {
                        int length = reader.length();
                        byte[] entryBytes = (entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++]
                                : new byte[ENTRY_BASESIZE_WHITESPACE + length];
                        table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature);
                        break;
                    }
                    else if (reader.matchesEnding(entry)) {
                        updateEntry(entry, temperature);
                        break;
                    }
                    else {
                        // Move to the next index
                        index = (index + 1) & TABLE_MASK;
                    }
                }
            }
            else {
                reader.processName();

                // Delimiter within the second 8 bytes of the name.
                if (!reader.readNext()) {

                    int temperature = reader.processEndAndGetTemperature();

                    // Find or insert the entry:
                    int index = (int) (reader.hash & TABLE_MASK);
                    while (true) {
                        entry = table[index];
                        if (entry == null) {
                            int length = reader.length();
                            byte[] entryBytes = (entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++]
                                    : new byte[ENTRY_BASESIZE_WHITESPACE + length];
                            table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature);
                            break;
                        }
                        else if (reader.matchesEntryShort(entry)) {
                            updateEntry(entry, temperature);
                            break;
                        }
                        else {
                            // Move to the next index
                            index = (index + 1) & TABLE_MASK;
                        }
                    }
                }
                else {
                    reader.processName();

                    // Delimiter within the third 8 bytes of the name.
                    if (!reader.readNext()) {
                        int temperature = reader.processEndAndGetTemperature();

                        // Find or insert the entry:
                        int index = (int) (reader.hash & TABLE_MASK);
                        while (true) {
                            entry = table[index];
                            if (entry == null) {
                                int length = reader.length();
                                byte[] entryBytes = (entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++]
                                        : new byte[ENTRY_BASESIZE_WHITESPACE + length];
                                table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature);
                                break;
                            }
                            else if (reader.matchesEntryMedium(entry)) {
                                updateEntry(entry, temperature);
                                break;
                            }
                            else {
                                // Move to the next index
                                index = (index + 1) & TABLE_MASK;
                            }
                        }

                    }
                    else {

                        // Longer names: keep reading 8 bytes at a time until the delimiter shows up.
                        reader.processName();
                        while (reader.readNext()) {
                            reader.processName();
                        }

                        int temperature = reader.processEndAndGetTemperature();

                        // Find or insert the entry:
                        int index = (int) (reader.hash & TABLE_MASK);
                        while (true) {
                            entry = table[index];
                            if (entry == null) {
                                int length = reader.length();
                                // Only here can a name exceed the capacity of a pre-constructed entry.
                                byte[] entryBytes = (length < PREMADE_MAX_SIZE && entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++]
                                        : new byte[ENTRY_BASESIZE_WHITESPACE + length]; // with enough room
                                table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature);
                                break;
                            }
                            else if (reader.matchesEntryFull(entry)) {
                                updateEntry(entry, temperature);
                                break;
                            }
                            else {
                                // Move to the next index
                                index = (index + 1) & TABLE_MASK;
                            }
                        }
                    }
                }
            }

        }
        return table;
    }

    /*
     * `___` ___ ___ _ ___` ` ___ ` _ ` _ ` _` ___
     * / ` \| _ \ __| \| \ \ / /_\ | | | | | | __|
     * | () | _ / __|| . |\ V / _ \| |_| |_| | ._|
     * \___/|_| |___|_|\_| \_/_/ \_\___|\___/|___|
     * ---------------- BETTER SOFTWARE, FASTER --
     *
     * https://www.openvalue.eu/
     */

    /**
     * Obtains the {@link Unsafe} singleton through reflection on its private
     * {@code theUnsafe} field.
     *
     * @return the {@code Unsafe} instance
     * @throws RuntimeException wrapping the reflective failure when the field is missing or inaccessible
     */
    private static Unsafe initUnsafe() {
        final Field singleton;
        try {
            singleton = Unsafe.class.getDeclaredField("theUnsafe");
            singleton.setAccessible(true);
            // Static field: the receiver argument is ignored.
            final Object instance = singleton.get(Unsafe.class);
            return (Unsafe) instance;
        }
        catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }
}