From 84f6331b835dd2f1c74c49ca9a01c63e87779fda Mon Sep 17 00:00:00 2001 From: Florin Blanaru Date: Sat, 27 Jan 2024 16:20:02 +0200 Subject: [PATCH] 1BRC gigiblender (#595) * Dirty implementation gigiblender * Final impl gigiblender --- calculate_average_gigiblender.sh | 19 + .../onebrc/CalculateAverage_gigiblender.java | 501 ++++++++++++++++++ 2 files changed, 520 insertions(+) create mode 100755 calculate_average_gigiblender.sh create mode 100644 src/main/java/dev/morling/onebrc/CalculateAverage_gigiblender.java diff --git a/calculate_average_gigiblender.sh b/calculate_average_gigiblender.sh new file mode 100755 index 0000000..7d51bdc --- /dev/null +++ b/calculate_average_gigiblender.sh @@ -0,0 +1,19 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +JAVA_OPTS="--enable-preview" +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_gigiblender diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_gigiblender.java b/src/main/java/dev/morling/onebrc/CalculateAverage_gigiblender.java new file mode 100644 index 0000000..162d712 --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_gigiblender.java @@ -0,0 +1,501 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import sun.misc.Unsafe; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.reflect.Field; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.TreeMap; + +public class CalculateAverage_gigiblender { + private static final int AVAIL_CORES = Runtime.getRuntime().availableProcessors(); + private static final HashTable[] tables = new HashTable[AVAIL_CORES]; + + private static Unsafe unsafe; + static { + Field theUnsafe = null; + try { + theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + unsafe = (Unsafe) theUnsafe.get(Unsafe.class); + } + catch (IllegalAccessException | NoSuchFieldException ignored) { + } + } + + private static final String FILE = "./measurements.txt"; + + static class HashTable { + + // 10_000 unique hashes -> + private static final int ENTRY_SIZE = 32; + private static final int NUM_ENTRIES = 16384; + private static final int DATA_SIZE = NUM_ENTRIES * ENTRY_SIZE; + + /* + * data[i -> i + 7] = 8 bytes hash + * data[i + 8 -> i + 15] = 7 bytes masked address of the string in the file. 1 byte for the length of the string + * data[i + 16 -> i + 19] = 4 bytes count + * data[i + 20 -> i + 21] = 2 bytes max + * data[i + 22 -> i + 23] = 2 bytes min -- sign preserved + * data[i + 24 -> i + 31] = 8 bytes sum + */ + byte[] data; + + private static final int HASH_OFFSET = 0; + + private static final int ADDR_OFFSET = 8; + private static final long ADDR_MASK = 0x00FFFFFFFFFFFFFFL; + private static final int STRING_LENGTH_SHIFT = 56; + + private static final int COUNT_OFFSET = 16; + + private static final int SUM_OFFSET = 24; + + private int reprobe_count; + + public HashTable() { + data = new byte[DATA_SIZE]; + // reprobe_count = 0; + } + + private long string_addr_and_length(long hash) { + return unsafe.getLong(data, Unsafe.ARRAY_BYTE_BASE_OFFSET + hash + ADDR_OFFSET); + } + + private static long string_addr(long encoded_str_addr) { + return (encoded_str_addr & ADDR_MASK); + } + + private static long string_length(long encoded_str_addr) { + return encoded_str_addr >>> STRING_LENGTH_SHIFT; + } + + private long count_max_min(long hash) { + return unsafe.getLong(data, Unsafe.ARRAY_BYTE_BASE_OFFSET + hash + COUNT_OFFSET); + } + + private static short mask_min(long count_max_min) { + // Preserve the sign + return (short) (count_max_min >> 6 * Byte.SIZE); + } + + private static short mask_max(long count_max_min) { + return (short) (count_max_min >>> 4 * Byte.SIZE); + } + + private static int mask_count(long count_max_min) { + return (int) count_max_min; + } + + private static long encode_count_max_min(int count, short max, short min) { + return ((long) count) | ((((long) max) & 0xFFFF) << 4 * Byte.SIZE) | (((long) min) << 6 * Byte.SIZE); + } + + private long sum(long hash) { + return unsafe.getLong(data, Unsafe.ARRAY_BYTE_BASE_OFFSET + hash + SUM_OFFSET); + } + + private static boolean string_equals(long string_addr, long entry_string_addr, int size_bytes) { + int remaining_bytes = size_bytes % 8; + int i = 0; + for (; i < size_bytes - remaining_bytes; i += 8) { + long entry_bytes = unsafe.getLong(entry_string_addr + i); + long string_bytes = unsafe.getLong(string_addr + i); + if (entry_bytes != string_bytes) { + return false; + } + } + // The hash function is not great, so I end up in this case a lot, so I take some risks. + // This never caused a SIGSEGV even though it might :) If it does, fall back to the commented version below. + // I will try to improve on the hash function + if (remaining_bytes != 0) { + long entry_bytes = unsafe.getLong(entry_string_addr + i); + long string_bytes = unsafe.getLong(string_addr + i); + // mask the bytes we care about + long mask = (1L << (remaining_bytes * Byte.SIZE)) - 1; + entry_bytes &= mask; + string_bytes &= mask; + return entry_bytes == string_bytes; + } + // for (; i < size_bytes; i++) { + // byte entry_byte = unsafe.getByte(entry_string_addr + i); + // byte string_byte = unsafe.getByte(string_addr + i); + // if (entry_byte != string_byte) { + // return false; + // } + // } + return true; + } + + public void insert(long hash, long string_addr, byte string_size, long final_number) { + assert string_addr >>> 56 == 0 : String.format("Expected final 8 bytes to be 0, got %s", Long.toBinaryString(string_addr)); + + long encoded_string_addr_and_length = string_addr | ((long) string_size << STRING_LENGTH_SHIFT); + assert string_addr(encoded_string_addr_and_length) == string_addr : String.format("Expected string addr to be %s, got %s", Long.toHexString(string_addr), + Long.toHexString(string_addr(encoded_string_addr_and_length))); + assert string_length(encoded_string_addr_and_length) == string_size + : String.format("Expected string length to be %s, got %s", string_size, string_length(encoded_string_addr_and_length)); + + long map_entry = apply_mask(hash * ENTRY_SIZE); + while (true) { + int entry_count0 = unsafe.getInt(data, Unsafe.ARRAY_BYTE_BASE_OFFSET + map_entry + COUNT_OFFSET); + if (entry_count0 == 0) { + // dump_insert(map_entry, hash, string_addr, string_size, final_number); + // Found an empty slot. Insert the entry here + unsafe.putLong(data, Unsafe.ARRAY_BYTE_BASE_OFFSET + map_entry + HASH_OFFSET, hash); + unsafe.putLong(data, Unsafe.ARRAY_BYTE_BASE_OFFSET + map_entry + ADDR_OFFSET, encoded_string_addr_and_length); + unsafe.putLong(data, Unsafe.ARRAY_BYTE_BASE_OFFSET + map_entry + COUNT_OFFSET, encode_count_max_min(1, (short) final_number, (short) final_number)); + unsafe.putLong(data, Unsafe.ARRAY_BYTE_BASE_OFFSET + map_entry + SUM_OFFSET, final_number); + + assert mask_count(encode_count_max_min(1, (short) final_number, (short) final_number)) == 1 : String.format("Expected count to be 1, got %s", + Integer.toBinaryString(mask_count(encode_count_max_min(1, (short) final_number, (short) final_number)))); + assert mask_max(encode_count_max_min(1, (short) final_number, (short) final_number)) == (short) final_number + : String.format("Expected max to be %s, got %s", final_number, + Integer.toBinaryString(mask_max(encode_count_max_min(1, (short) final_number, (short) final_number)))); + assert mask_min(encode_count_max_min(1, (short) final_number, (short) final_number)) == (short) final_number + : String.format("Expected min to be %s, got %s", final_number, + Integer.toBinaryString(mask_min(encode_count_max_min(1, (short) final_number, (short) final_number)))); + return; + } + else { + // Check if strings match. If yes, update. Otherwise, look for the next available slot + long entry_string_addr_and_length = string_addr_and_length(map_entry); + long entry_str_size = string_length(entry_string_addr_and_length); + + if (string_size != entry_str_size) { + // Strings are not the same size. Continue looking for the next slot + map_entry = apply_mask(map_entry + ENTRY_SIZE); + // reprobe_count++; + } + else { + long entry_string_addr = string_addr(entry_string_addr_and_length); + if (string_equals(string_addr, entry_string_addr, string_size)) { + // Strings are the same. Update the entry + long entry_count_max_min = count_max_min(map_entry); + int entry_count = mask_count(entry_count_max_min); + short entry_max = mask_max(entry_count_max_min); + short entry_min = mask_min(entry_count_max_min); + + entry_count++; + assert (int) final_number == final_number : String.format("Expected final number to be an int, got %s", final_number); + entry_max = (short) Math.max(entry_max, (int) final_number); + entry_min = (short) Math.min(entry_min, (int) final_number); + + long entry_sum = sum(map_entry); + entry_sum += final_number; + + unsafe.putLong(data, Unsafe.ARRAY_BYTE_BASE_OFFSET + map_entry + COUNT_OFFSET, encode_count_max_min(entry_count, entry_max, entry_min)); + unsafe.putLong(data, Unsafe.ARRAY_BYTE_BASE_OFFSET + map_entry + SUM_OFFSET, entry_sum); + return; + } + else { + // Strings are not the same. Continue looking for the next slot + map_entry = apply_mask(map_entry + ENTRY_SIZE); + // reprobe_count++; + } + } + } + } + } + + private static long apply_mask(long hash) { + return hash & (DATA_SIZE - 1); + } + + public void update_res(TreeMap result_map) { + // System.err.println("Reprobe count: " + reprobe_count); + Result r = new Result(); + + for (int i = 0; i < NUM_ENTRIES; i++) { + long entry_addr_offset = (long) i * ENTRY_SIZE; + long entry_count_max_min = count_max_min(entry_addr_offset); + int entry_count = mask_count(entry_count_max_min); + if (entry_count == 0) { + continue; + } + long entry_string_addr_and_length = string_addr_and_length(entry_addr_offset); + long entry_string_addr = string_addr(entry_string_addr_and_length); + long entry_string_length = string_length(entry_string_addr_and_length); + + // no reason to copy the byte array twice here but what can you do... + byte[] bytes = new byte[(int) entry_string_length]; + unsafe.copyMemory(null, entry_string_addr, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, entry_string_length); + String s = new String(bytes, StandardCharsets.UTF_8); + + short entry_max = mask_max(entry_count_max_min); + short entry_min = mask_min(entry_count_max_min); + + long entry_sum = sum(entry_addr_offset); + + Result ret = result_map.putIfAbsent(s, r); + if (ret == null) { + r.count = entry_count; + r.max = entry_max; + r.min = entry_min; + r.sum = entry_sum; + r = new Result(); + } + else { + ret.count += entry_count; + ret.max = (short) Math.max(ret.max, entry_max); + ret.min = (short) Math.min(ret.min, entry_min); + ret.sum += entry_sum; + } + } + } + + public void dump_insert(long map_entry, long hash, long string_addr, byte string_size, long final_number) { + System.out.println("START dump_insert"); + System.out.println("Inserting " + final_number + " with hash " + hash); + System.out.println("Map entry: " + map_entry); + System.out.println("String addr: " + string_addr + " with length " + string_size); + dump(string_addr, string_addr + string_size); + System.out.println("END dump_insert"); + } + } + + static class Result { + public int count; + public short max; + public short min; + public long sum; + + private double round(double value) { + return Math.round(value * 10.0) / 10.0; + } + + @Override + public String toString() { + return round(min / 10.) + "/" + round(sum / (double) (10 * count)) + "/" + round(max / 10.); + } + } + + private static void compute_slice(final long base_addr, final long slice_size, final long file_size, final int thread_index) { + HashTable my_table; + if (!SINGLE_CORE) { + my_table = new HashTable(); + tables[thread_index] = my_table; + } + else { + if (tables[0] == null) { + tables[0] = new HashTable(); + } + my_table = tables[0]; + } + + long cur_addr = base_addr + (long) thread_index * slice_size; + // Lookup the next newline. If thread_index == 0 then start right away + if (thread_index != 0) { + while (unsafe.getByte(cur_addr) != '\n') { + cur_addr++; + } + cur_addr++; + } + + long end_addr = base_addr + (long) (thread_index + 1) * slice_size; + if (thread_index == (AVAIL_CORES - 1)) { + // Last thread. We need to read until the end of the file + end_addr = base_addr + file_size; + } + else { + // look ahead for the next newline + while (unsafe.getByte(end_addr) != '\n') { + end_addr++; + } + end_addr++; + } + + // We now have a well-defined interval [cur_addr, end_addr) to work on + long hash = -2346162244362633811L; + byte string_size = 0; + long string_addr = cur_addr; + while (cur_addr < end_addr) { + long value_mem = unsafe.getLong(cur_addr); + int semicolon_byte_index = get_semicolon_index(value_mem); + + string_size += (byte) semicolon_byte_index; + + // dump(cur_addr, cur_addr + semicolon_byte_index); + + if (semicolon_byte_index != 8) { + long value_mem_up_to_semicolon = value_mem & ((1L << (semicolon_byte_index * Byte.SIZE)) - 1); + + // We have a semicolon, so the hash is complete now. We can construct the number + // and insert it into the hash table + long start_num_addr = cur_addr + semicolon_byte_index + 1; + + // Always read the next 8 bytes for the number. It seems that this is faster than + // checking if the whole number is in the current 8 bytes and only reading if it is not + long number_mem_value = unsafe.getLong(start_num_addr); + long number_len_bytes = get_newline_index(number_mem_value); + + long final_number = extract_number(number_mem_value, number_len_bytes); + + // 0.2421196 % reprobe rate + hash = compute_hash(hash ^ value_mem_up_to_semicolon); + + // We have the final number now. We can insert it into the hash table + my_table.insert(hash, string_addr, string_size, final_number); + // Now we can move on to the next line + hash = -2346162244362633811L; + string_size = 0; + cur_addr = start_num_addr + number_len_bytes + 1; + string_addr = cur_addr; + } + else { + // No semicolon in the 8 bytes read. Continue reading + hash = hash ^ value_mem; + cur_addr += 8; + } + } + assert cur_addr == end_addr : String.format("Expected cur_addr to be %s, got %s", end_addr, cur_addr); + } + + private static long extract_number(long number_mem_value, long number_len_bytes) { + // Pray for GVN/CSE and Sea of Nodes moving the mess below in the proper places because + // I don't want to spend the time to do it properly :) + long number_mem_dot_index = get_dot_index(number_mem_value); + + int fractional_part = get_fractional_part(number_mem_value, number_len_bytes); + int sign = get_sign(number_mem_value); + int skip_sign = skip_sign(number_mem_value); + + long number_mem_value_no_sign = number_mem_value >>> (skip_sign << 3); + // Two cases: either there's a single digit before the dot, or there's two + // Start from the dot index and go backwards + long new_number_mem_dot_index = number_mem_dot_index - skip_sign; + long read_byte_mask = 0xFFL << ((new_number_mem_dot_index - 1) * Byte.SIZE); + long ones = ((number_mem_value_no_sign & read_byte_mask) >>> ((new_number_mem_dot_index - 1) * Byte.SIZE)) - 0x30; + // Should be 0 due to the multiplication if there's only one digit before the dot + long tens = ((number_mem_value_no_sign & 0xFFL) - 0x30) * (new_number_mem_dot_index - 1); + + long final_number = (tens * 100 + ones * 10 + fractional_part) * sign; + return final_number; + } + + private static int get_fractional_part(long number_mem_value, long number_len_bytes) { + return (int) ((number_mem_value >>> ((number_len_bytes - 1) * Byte.SIZE)) & 0xFF) - 0x30; + } + + private static int skip_sign(long number_mem_value) { + // return 1 if char is '-', 0 if it is not + long diff = (number_mem_value & 0xFF) - 0x2D; + long sign = (diff | -diff) >>> 63; + return (int) ((sign - 1) * -1); + } + + private static int get_sign(long number_mem_value) { + // return 1 if char is not '-', -1 if it is + long diff = (number_mem_value & 0xFF) - 0x2D; + long sign = (diff | -diff) >>> 63; + return (int) (-2 * sign + 1) * -1; + } + + private static long compute_hash(long x) { // Hash burrowed from artsiomkorzun and slightly changed + long h = x * -7046029254386353131L; + long h1 = h ^ (h >>> 32); + h = h ^ (h << 32); + return h1 ^ h; + } + + private static void dump(long startAddr, long endAddr) { + byte[] bytes = new byte[(int) (endAddr - startAddr)]; + unsafe.copyMemory(null, startAddr, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, bytes.length); + String s = new String(bytes, StandardCharsets.UTF_8); + System.out.println(s); + // Dump the bytes to binary form + for (byte b : bytes) { + System.out.print(Integer.toBinaryString(b & 0xFF)); + System.out.print(" "); + } + System.out.println(); + // Dump the bytes to hex form + for (byte b : bytes) { + System.out.print(Integer.toHexString(b & 0xFF)); + System.out.print(" "); + } + System.out.println(); + } + + private static int get_byte_0_index(long value) { + long res = (value - 0x0101010101010101L) & (~value & 0x8080808080808080L); + res = Long.numberOfTrailingZeros(res) >> 3; + return (int) res; + } + + private static int get_dot_index(long value) { + long temp = value ^ 0x2E2E2E2E2E2E2E2EL; + return get_byte_0_index(temp); + } + + private static int get_newline_index(long value) { + long temp = value ^ 0x0A0A0A0A0A0A0A0AL; + return get_byte_0_index(temp); + } + + private static int get_semicolon_index(long value) { + long temp = value ^ 0x3B3B3B3B3B3B3B3BL; + return get_byte_0_index(temp); + } + + private static final boolean SINGLE_CORE = false; + + public static void main(String[] args) throws IOException, InterruptedException { + FileChannel file_channel = FileChannel.open(Paths.get(FILE), StandardOpenOption.READ); + long file_size = file_channel.size(); + long base_addr = file_channel.map(FileChannel.MapMode.READ_ONLY, 0, file_size, Arena.global()).address(); + + if (!SINGLE_CORE) { + int num_threads = AVAIL_CORES; + Thread[] threads = new Thread[num_threads]; + for (int i = 0; i < num_threads; i++) { + int finalI = i; + threads[i] = new Thread(() -> { + long slice_size = file_size / AVAIL_CORES; + compute_slice(base_addr, slice_size, file_size, finalI); + }); + threads[i].start(); + } + + TreeMap result_map = new TreeMap<>(); + for (int i = 0; i < num_threads; i++) { + threads[i].join(); + tables[i].update_res(result_map); + } + + System.out.println(result_map); + } + else { + for (int i = 0; i < AVAIL_CORES; i++) { + int finalI = i; + long slice_size = file_size / AVAIL_CORES; + compute_slice(base_addr, slice_size, file_size, finalI); + } + + TreeMap result_map = new TreeMap<>(); + tables[0].update_res(result_map); + + System.out.println(result_map); + } + } +}