diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_hundredwatt.java b/src/main/java/dev/morling/onebrc/CalculateAverage_hundredwatt.java index 051de9c..9d935ff 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_hundredwatt.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_hundredwatt.java @@ -17,6 +17,7 @@ package dev.morling.onebrc; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.file.Files; @@ -36,7 +37,7 @@ public class CalculateAverage_hundredwatt { private static final long FILE_CHUNK_SIZE = CHUNK_SIZE - MAX_ROW_SIZE; public static final int TEMPERATURE_SLOTS = 5003; // prime number private static final short[] TEMPERATURES = new short[TEMPERATURE_SLOTS]; - private static final long PERFECT_HASH_SEED = -5353381455852817461L; + private static final long PERFECT_HASH_SEED = -1982870890352534081L; // Construct a perfect hash function mapping temperatures encoded as longs (e.g., 0x2d342e3000000000 for -4.3) to // the corresponding short integer (e.g., -43). 
@@ -46,17 +47,17 @@ public class CalculateAverage_hundredwatt { Map decodeTemperatureMap = new HashMap<>(); for (short i = -999; i <= 999; i++) { long word = 0; - int shift = 56; + int shift = 0; if (i < 0) { word |= ((long) '-') << shift; - shift -= 8; + shift += 8; } if (Math.abs(i) >= 100) { int hh = Math.abs(i) / 100; int tt = (Math.abs(i) - hh * 100) / 10; word |= ((long) (hh + '0')) << shift; - shift -= 8; + shift += 8; word |= ((long) (tt + '0')) << shift; } else { @@ -64,9 +65,9 @@ public class CalculateAverage_hundredwatt { // convert to ascii word |= ((long) (tt + '0')) << shift; } - shift -= 8; + shift += 8; word |= ((long) '.') << shift; - shift -= 8; + shift += 8; int uu = Math.abs(i) % 10; word |= ((long) (uu + '0')) << shift; @@ -74,31 +75,6 @@ public class CalculateAverage_hundredwatt { decodeTemperatureMap.put(word, i); } - // Brute force to find seed: - // Random rand = new Random(System.nanoTime()); - // int max = 0; - // int attempts = 0; - // while (true) { - // BitSet bs = new BitSet(5003); - // var seed = rand.nextLong(); - // seed |= 0b1; // make sure it's odd - // for (var word : decodeTemperatureMap.keySet()) { - // var h = (word * seed) & ~(1L << 63); - // var pos = (int) (h % 5003); - // bs.set(pos); - // } - // ; - // var c = bs.cardinality(); - // if (c == decodeTemperatureMap.size()) { - // System.out.println("seed: " + seed + " cardinality: " + c + " max cardinality: " + max); - // break; - // } - // max = Math.max(max, c); - // if (attempts % 100_000 == 0) - // System.out.println("seed: " + seed + " cardinality: " + c + " max cardinality: " + max); - // attempts++; - // } - decodeTemperatureMap.entrySet().stream().forEach(e -> { var word = e.getKey(); var h = (word * PERFECT_HASH_SEED) & ~(1L << 63); @@ -144,11 +120,12 @@ public class CalculateAverage_hundredwatt { } static class HashTable { - private static final int INITIAL_SIZE = 128 * 1024; + private static final int INITIAL_SIZE = 16 * 1024; private static final float 
LOAD_FACTOR = 0.75f; private static final int GROW_FACTOR = 4; private final long[][] KEYS = new long[INITIAL_SIZE][]; private final Record[] VALUES = new Record[INITIAL_SIZE]; + private final long[] HASHES = new long[INITIAL_SIZE]; private int size = INITIAL_SIZE; public HashTable() { @@ -162,13 +139,14 @@ public class CalculateAverage_hundredwatt { // linear probing int i = 0; - while (KEYS[idx] != null && (0 != Arrays.compareUnsigned(KEYS[idx], 0, KEYS[idx].length, key, 0, length))) { + while (KEYS[idx] != null && (HASHES[idx] != hash) && (0 != Arrays.compareUnsigned(KEYS[idx], 0, KEYS[idx].length, key, 0, length))) { i++; idx = (idx + 1) & (size - 1); } if (KEYS[idx] == null) { KEYS[idx] = Arrays.copyOf(key, length); + HASHES[idx] = hash; } VALUES[idx].updateWith(value); @@ -186,7 +164,7 @@ public class CalculateAverage_hundredwatt { } private static String keyToString(long[] key) { - ByteBuffer kb = ByteBuffer.allocate(8 * key.length); + ByteBuffer kb = ByteBuffer.allocate(8 * key.length).order(ByteOrder.LITTLE_ENDIAN); Arrays.stream(key).forEach(kb::putLong); // remove trailing '\0' bytes from kb and @@ -194,7 +172,7 @@ public class CalculateAverage_hundredwatt { byte b; int limit = kb.position() - 8; kb.position(limit); - while ((b = kb.get()) != 0 && b != ';') { + while ((b = kb.get()) != 0 && b != ';' && limit < kb.capacity() - 1) { limit++; } @@ -202,7 +180,7 @@ public class CalculateAverage_hundredwatt { byte[] bytes = new byte[limit]; kb.get(bytes); - return new String(bytes).replace("\0", ""); + return new String(bytes); } private static Record merge(Record v, Record value) { @@ -215,6 +193,7 @@ public class CalculateAverage_hundredwatt { } private static int processChunk(ByteBuffer bb, HashTable hashTable, long start, long size) { + bb.order(ByteOrder.LITTLE_ENDIAN); // Find first entry while (start != 0 && bb.get() != '\n') { } @@ -228,6 +207,7 @@ public class CalculateAverage_hundredwatt { long temperature_hash; int temperature_pos; short 
temperature_value; + int hashInt; int i = 0; int end = (int) (size - MAX_ROW_SIZE); @@ -258,29 +238,28 @@ public class CalculateAverage_hundredwatt { hasvalue = (op1 & op2 & 0x8080808080808080L); } hash ^= key[offset]; // unset last word since it will be updated - key[offset] = key[offset] & (-hasvalue); // ';' == 0x3b and -hasvalue is something like 0xff8000..., we can ignore the 0x80 byte since 0x3b & 0x80 == 0 + key[offset] = key[offset] & ~(-(hasvalue >> 7)); hash ^= key[offset]; - position = position + offset * 8 + (Long.numberOfLeadingZeros(hasvalue)) / 8 + 1; // +1 for \n + position = position + offset * 8 + Long.numberOfTrailingZeros(hasvalue) / 8 + 1; // +1 for \n // Parse temperature word = bb.getLong(position); - arg = (word) ^ 0x0101010101010101L * ('\n'); - op1 = (arg - 0x0101010101010101L); - op2 = ~(arg); - hasvalue = (op1 & op2 & 0x8080808080808080L); - word = word & ((-hasvalue)); + hasvalue = (word - 0x0B0B0B0B0B0B0B0BL) & 0x8080808080808080L; + int newlinePos = Long.numberOfTrailingZeros(hasvalue) - 8; + + word = word & (~(-(1L << newlinePos))); // Perfect hash lookup for temperature temperature_hash = (word * PERFECT_HASH_SEED) & ~(1L << 63); temperature_pos = (int) (temperature_hash % TEMPERATURE_SLOTS); temperature_value = TEMPERATURES[temperature_pos]; - position = position + (Long.numberOfLeadingZeros(hasvalue)) / 8 + 1; // +1 for \n + position = position + newlinePos / 8 + 2; // +1 for \n - int hash2 = (int) (hash ^ (hash >> 32)); + hashInt = (int) (hash ^ (hash >> 32)); - hashTable.putOrMerge(hash2, offset + 1, key, temperature_value); + hashTable.putOrMerge(hashInt, offset + 1, key, temperature_value); } return i; } diff --git a/src/main/java/dev/morling/onebrc/PerfectHashSearch_hundredwatt.java b/src/main/java/dev/morling/onebrc/PerfectHashSearch_hundredwatt.java new file mode 100644 index 0000000..18c1926 --- /dev/null +++ b/src/main/java/dev/morling/onebrc/PerfectHashSearch_hundredwatt.java @@ -0,0 +1,170 @@ +/* + * Copyright 2023 The 
original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.BitSet; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.random.RandomGenerator; +import java.util.random.RandomGeneratorFactory; + +/** Offline script used to find the perfect hash seed for CalculateAverage_hundredwatt. 
/**
 * Offline script used to find the perfect hash seed for CalculateAverage_hundredwatt.
 *
 * <p>Every temperature from -99.9 to 99.9 (stored as tenths, -999..999) is encoded as the
 * little-endian {@code long} formed by its ASCII text. The search then draws random odd seeds
 * until one maps all 1999 encodings to distinct slots via
 * {@code ((word * seed) & Long.MAX_VALUE) % DESIRED_SLOTS}. The winning seed is printed and
 * written to {@code seeds.txt}.
 */
public class PerfectHashSearch_hundredwatt {
    public static final int DESIRED_SLOTS = 5003;
    // At least one worker: availableProcessors() - 1 is 0 on a single-core host, and
    // Executors.newFixedThreadPool(0) throws IllegalArgumentException.
    public static final int N_THREADS = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);

    /** Number of candidate seeds a worker tries between progress reports / stop checks. */
    private static final int REPORT_INTERVAL = 100_000;
    private static final String SEED_FILE = "seeds.txt";

    /**
     * Runs the brute-force search on {@link #N_THREADS} workers, verifies the winning seed and
     * writes it to {@code seeds.txt}.
     *
     * @param args unused
     * @throws IOException declared for interface compatibility; write failures are rethrown
     *         wrapped in a {@link RuntimeException}
     * @throws InterruptedException if the main thread is interrupted while waiting for workers
     * @throws IllegalStateException if the workers stop without finding a seed
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        // 0 doubles as "not found yet"; candidate seeds are always odd, so never 0.
        AtomicLong magicSeed = new AtomicLong(0);
        AtomicLong totalAttempts = new AtomicLong(0);
        AtomicLong maxCardinality = new AtomicLong(0);

        long start = System.currentTimeMillis();

        System.out.println("Searching for perfect hash seed for " + DESIRED_SLOTS + " slots");

        // Figure out encoding for all possible temperature values (1999 total)
        Map<Long, Short> decodeTemperatureMap = buildTemperatureEncodings();
        // Unboxed once so the hot loop does not unbox 1999 keys per candidate seed.
        long[] words = decodeTemperatureMap.keySet().stream().mapToLong(Long::longValue).toArray();

        RandomGeneratorFactory<RandomGenerator> factory = RandomGeneratorFactory.of("L64X256MixRandom");

        Runnable search = () -> {
            // Each worker gets its own generator, seeded from a cryptographically strong source
            // so that workers explore independent candidate sequences.
            RandomGenerator rand;
            try {
                byte[] seedBytes = new byte[Long.BYTES];
                SecureRandom.getInstanceStrong().nextBytes(seedBytes);
                long generatorSeed = ByteBuffer.wrap(seedBytes).getLong();
                rand = factory.create(generatorSeed);
                System.out.println(Thread.currentThread().getName() + " | Using seed: " + generatorSeed);
            }
            catch (NoSuchAlgorithmException e) {
                throw new RuntimeException(e);
            }

            int max = 0;
            long attempts = 0;
            while (true) {
                long candidate = rand.nextLong() | 1L; // make sure it's odd
                int c = countDistinctSlots(words, candidate);
                if (c == words.length) {
                    System.out.println("FOUND seed: " + candidate + " cardinality: " + c + " max cardinality: " + max);
                    // Keep the first seed found if several workers succeed.
                    magicSeed.compareAndSet(0, candidate);
                    // Account for the attempts since this worker's last report.
                    totalAttempts.addAndGet(attempts % REPORT_INTERVAL + 1);
                    return;
                }
                max = Math.max(max, c);
                attempts++;
                if (attempts % REPORT_INTERVAL == 0) {
                    long currentTotalAttempts = totalAttempts.addAndGet(REPORT_INTERVAL);
                    if (magicSeed.get() != 0 || Thread.currentThread().isInterrupted()) {
                        return;
                    }
                    long currentMaxCardinality = maxCardinality.accumulateAndGet(max, Math::max);

                    // Only the first pool thread reports, to keep the output readable.
                    if (Thread.currentThread().getName().endsWith("-1")) {
                        System.out.println(Thread.currentThread().getName() + " | max cardinality: " + currentMaxCardinality + " attempts: "
                                + String.format("%,d", currentTotalAttempts));
                    }
                }
            }
        };

        ExecutorService executor = Executors.newFixedThreadPool(N_THREADS);
        // One task per pool thread. execute() (not submit()) so that a failing worker reports
        // its exception instead of having it silently captured in an unread Future.
        for (int i = 0; i < N_THREADS; i++) {
            executor.execute(search);
        }

        // Wait for the search to complete
        executor.shutdown();
        if (!executor.awaitTermination(1, TimeUnit.DAYS)) {
            // Workers poll the interrupt flag at every report interval.
            executor.shutdownNow();
        }

        long seed = magicSeed.get();
        if (seed == 0) {
            throw new IllegalStateException("search ended without finding a perfect hash seed");
        }

        // Independent re-check of the winning seed; throws on any collision.
        buildLookupTable(decodeTemperatureMap, seed);
        System.out.println("SUCCESS seed: " + seed + " total attempts: " + totalAttempts.get());

        // Write the seed to seeds.txt (FileWriter creates or truncates the file).
        try (FileWriter seedWriter = new FileWriter(SEED_FILE)) {
            seedWriter.write(Long.toString(seed));
            seedWriter.write("\n");
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }

        System.out.println("Search took " + ((System.currentTimeMillis() - start) / 1000) + "s");
    }

    /**
     * Builds the encoding of every temperature in -999..999 (tenths of a degree).
     *
     * @return map from encoded word to the temperature in tenths; 1999 entries
     */
    static Map<Long, Short> buildTemperatureEncodings() {
        Map<Long, Short> encodings = new HashMap<>();
        for (short i = -999; i <= 999; i++) {
            encodings.put(encodeTemperature(i), i);
        }
        return encodings;
    }

    /**
     * Encodes a temperature as the little-endian word of its ASCII text: the first character
     * occupies the lowest byte. For example {@code -43} ("-4.3") becomes {@code 0x332e342d}.
     *
     * @param tenths temperature in tenths of a degree, expected in -999..999
     * @return the ASCII bytes of the formatted temperature packed into a long
     */
    static long encodeTemperature(short tenths) {
        int magnitude = Math.abs(tenths);
        long word = 0;
        int shift = 0;
        if (tenths < 0) {
            word |= ((long) '-') << shift;
            shift += 8;
        }
        if (magnitude >= 100) {
            // two integer digits: emit the leading one
            word |= ((long) (magnitude / 100 + '0')) << shift;
            shift += 8;
        }
        word |= ((long) (magnitude / 10 % 10 + '0')) << shift;
        shift += 8;
        word |= ((long) '.') << shift;
        shift += 8;
        word |= ((long) (magnitude % 10 + '0')) << shift;
        return word;
    }

    /** Slot of {@code word} under {@code seed}: multiply, clear the sign bit, reduce modulo the table size. */
    static int slotFor(long word, long seed) {
        long h = (word * seed) & ~(1L << 63);
        return (int) (h % DESIRED_SLOTS);
    }

    /**
     * Counts how many distinct slots the given words occupy under {@code seed}.
     *
     * @return the number of distinct slots; equals {@code words.length} when the seed is
     *         collision-free for these words
     */
    static int countDistinctSlots(long[] words, long seed) {
        BitSet used = new BitSet(DESIRED_SLOTS);
        for (long word : words) {
            used.set(slotFor(word, seed));
        }
        return used.cardinality();
    }

    /**
     * Builds the slot-to-temperature table for {@code seed}.
     *
     * <p>Occupancy is tracked separately from the stored values, because a slot holding the
     * temperature 0 is indistinguishable from an empty slot in the table itself.
     *
     * @return a table of {@link #DESIRED_SLOTS} entries
     * @throws IllegalStateException if two encodings map to the same slot
     */
    static short[] buildLookupTable(Map<Long, Short> encodings, long seed) {
        short[] table = new short[DESIRED_SLOTS];
        BitSet occupied = new BitSet(DESIRED_SLOTS);
        for (Map.Entry<Long, Short> entry : encodings.entrySet()) {
            int pos = slotFor(entry.getKey(), seed);
            if (occupied.get(pos)) {
                throw new IllegalStateException("collision at " + pos);
            }
            occupied.set(pos);
            table[pos] = entry.getValue();
        }
        return table;
    }
}