Fixing the off-by-one error and updating to native, redone layout of code. (#307)
This commit is contained in:
		| @@ -15,5 +15,16 @@ | |||||||
| #  limitations under the License. | #  limitations under the License. | ||||||
| # | # | ||||||
|  |  | ||||||
| JAVA_OPTS="--enable-preview" | if [ -f target/CalculateAverage_royvanrijn_image ]; then | ||||||
| java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_royvanrijn |     echo "Picking up existing native image 'target/CalculateAverage_royvanrijn_image', delete the file to select JVM mode." 1>&2 | ||||||
|  |     target/CalculateAverage_royvanrijn_image | ||||||
|  | else | ||||||
|  |     JAVA_OPTS="--enable-preview -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA" | ||||||
|  |     if [[ ! "$(uname -s)" = "Darwin" ]]; then | ||||||
|  |         # On OS/X, my machine, this errors: | ||||||
|  |         JAVA_OPTS="$JAVA_OPTS -XX:+UseTransparentHugePages" | ||||||
|  |     fi | ||||||
|  |     echo "Choosing to run the app in JVM mode as no native image was found, use additional_build_step_royvanrijn.sh to generate." 1>&2 | ||||||
|  |     java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_royvanrijn | ||||||
|  | fi | ||||||
|  |  | ||||||
|   | |||||||
| @@ -17,3 +17,12 @@ | |||||||
|  |  | ||||||
| source "$HOME/.sdkman/bin/sdkman-init.sh" | source "$HOME/.sdkman/bin/sdkman-init.sh" | ||||||
| sdk use java 21.0.1-graal 1>&2 | sdk use java 21.0.1-graal 1>&2 | ||||||
|  |  | ||||||
|  | # ./mvnw clean verify removes target/ and will re-trigger native image creation. | ||||||
|  | if [ ! -f target/CalculateAverage_royvanrijn_image ]; then | ||||||
|  |  | ||||||
|  |     JAVA_OPTS="--enable-preview -dsa" | ||||||
|  |     NATIVE_IMAGE_OPTS="--gc=epsilon -Ob -O3 -march=native --strict-image-heap $JAVA_OPTS" | ||||||
|  |  | ||||||
|  |     native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_royvanrijn_image dev.morling.onebrc.CalculateAverage_royvanrijn | ||||||
|  | fi | ||||||
|   | |||||||
| @@ -18,16 +18,15 @@ package dev.morling.onebrc; | |||||||
| import java.io.IOException; | import java.io.IOException; | ||||||
| import java.lang.foreign.Arena; | import java.lang.foreign.Arena; | ||||||
| import java.lang.reflect.Field; | import java.lang.reflect.Field; | ||||||
| import java.nio.ByteOrder; |  | ||||||
| import java.nio.channels.FileChannel; | import java.nio.channels.FileChannel; | ||||||
|  | import java.nio.charset.StandardCharsets; | ||||||
| import java.nio.file.Path; | import java.nio.file.Path; | ||||||
| import java.nio.file.StandardOpenOption; | import java.nio.file.StandardOpenOption; | ||||||
| import java.util.Arrays; | import java.util.HashMap; | ||||||
| import java.util.Objects; | import java.util.List; | ||||||
| import java.util.TreeMap; | import java.util.Map; | ||||||
| import java.util.stream.Collectors; | import java.util.stream.Collectors; | ||||||
| import java.util.stream.IntStream; | import java.util.stream.IntStream; | ||||||
| import java.util.stream.Stream; |  | ||||||
|  |  | ||||||
| import sun.misc.Unsafe; | import sun.misc.Unsafe; | ||||||
|  |  | ||||||
| @@ -49,21 +48,22 @@ import sun.misc.Unsafe; | |||||||
|  * Inlining hash calculation:        2450 ms |  * Inlining hash calculation:        2450 ms | ||||||
|  * Replacing branchless code:        2200 ms (sometimes we need to kill the things we love) |  * Replacing branchless code:        2200 ms (sometimes we need to kill the things we love) | ||||||
|  * Added unsafe memory access:       1900 ms (keeping the long[] small and local) |  * Added unsafe memory access:       1900 ms (keeping the long[] small and local) | ||||||
|  |  * Fixed bug, UNSAFE bytes String:   1850 ms | ||||||
|  |  * Separate hash from entries:       1550 ms | ||||||
|  |  * Various tweaks for Linux/cache    1550 ms (should/could make a difference on target machine) | ||||||
|  |  * Improved layout/predictability    1450 ms (on par with Thomas Wuerthinger) | ||||||
|  * |  * | ||||||
|  * Best performing JVM on MacBook M2 Pro: 21.0.1-graal |  * Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai for ideas. | ||||||
|  * `sdk use java 21.0.1-graal` |  | ||||||
|  * |  | ||||||
|  */ |  */ | ||||||
| public class CalculateAverage_royvanrijn { | public class CalculateAverage_royvanrijn { | ||||||
|  |  | ||||||
|     private static final String FILE = "./measurements.txt"; |     private static final String FILE = "./measurements.txt"; | ||||||
|  |  | ||||||
|     private static final Unsafe UNSAFE = initUnsafe(); |     private static final Unsafe UNSAFE = initUnsafe(); | ||||||
|     private static final boolean isBigEndian = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); |  | ||||||
|  |  | ||||||
|     private static Unsafe initUnsafe() { |     private static Unsafe initUnsafe() { | ||||||
|         try { |         try { | ||||||
|             Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); |             final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); | ||||||
|             theUnsafe.setAccessible(true); |             theUnsafe.setAccessible(true); | ||||||
|             return (Unsafe) theUnsafe.get(Unsafe.class); |             return (Unsafe) theUnsafe.get(Unsafe.class); | ||||||
|         } |         } | ||||||
| @@ -73,32 +73,42 @@ public class CalculateAverage_royvanrijn { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     public static void main(String[] args) throws Exception { |     public static void main(String[] args) throws Exception { | ||||||
|         new CalculateAverage_royvanrijn().run(); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     public void run() throws Exception { |  | ||||||
|  |  | ||||||
|         // Calculate input segments. |         // Calculate input segments. | ||||||
|         int numberOfChunks = Runtime.getRuntime().availableProcessors(); |         final int numberOfChunks = Runtime.getRuntime().availableProcessors(); | ||||||
|         long[] chunks = getSegments(numberOfChunks); |         final long[] chunks = getSegments(numberOfChunks); | ||||||
|  |  | ||||||
|         // Parallel processing of segments. |         final List<Entry[]> repositories = IntStream.range(0, chunks.length - 1) | ||||||
|         TreeMap<String, Measurement> results = IntStream.range(0, chunks.length - 1) |                 .mapToObj(chunkIndex -> processMemoryArea(chunks[chunkIndex], chunks[chunkIndex + 1])) | ||||||
|                 .mapToObj(chunkIndex -> process(chunks[chunkIndex], chunks[chunkIndex + 1])).parallel() |                 .parallel() | ||||||
|                 .flatMap(MeasurementRepository::get) |                 .toList(); | ||||||
|                 .collect(Collectors.toMap(e -> e.city, MeasurementRepository.Entry::measurement, Measurement::updateWith, TreeMap::new)); |  | ||||||
|  |  | ||||||
|         System.out.println(results); |         // Sometimes simple is better: | ||||||
|  |         final HashMap<String, Entry> measurements = HashMap.newHashMap(1 << 10); | ||||||
|  |         for (Entry[] entries : repositories) { | ||||||
|  |             for (Entry entry : entries) { | ||||||
|  |                 if (entry != null) | ||||||
|  |                     measurements.merge(entry.city, entry, Entry::mergeWith); | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|     private static long[] getSegments(int numberOfChunks) throws IOException { |         System.out.print("{" + | ||||||
|  |                 measurements.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Object::toString).collect(Collectors.joining(", "))); | ||||||
|  |         System.out.println("}"); | ||||||
|  |  | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /** | ||||||
|  |      * Simpler way to get the segments and launch parallel processing by thomaswue | ||||||
|  |      */ | ||||||
|  |     private static long[] getSegments(final int numberOfChunks) throws IOException { | ||||||
|         try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { |         try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { | ||||||
|             long fileSize = fileChannel.size(); |             final long fileSize = fileChannel.size(); | ||||||
|             long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks; |             final long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks; | ||||||
|             long[] chunks = new long[numberOfChunks + 1]; |             final long[] chunks = new long[numberOfChunks + 1]; | ||||||
|             long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); |             final long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); | ||||||
|             chunks[0] = mappedAddress; |             chunks[0] = mappedAddress; | ||||||
|             long endAddress = mappedAddress + fileSize; |             final long endAddress = mappedAddress + fileSize; | ||||||
|             for (int i = 1; i < numberOfChunks; ++i) { |             for (int i = 1; i < numberOfChunks; ++i) { | ||||||
|                 long chunkAddress = mappedAddress + i * segmentSize; |                 long chunkAddress = mappedAddress + i * segmentSize; | ||||||
|                 // Align to first row start. |                 // Align to first row start. | ||||||
| @@ -112,108 +122,36 @@ public class CalculateAverage_royvanrijn { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     private MeasurementRepository process(long fromAddress, long toAddress) { |     private static final int TABLE_SIZE = 1 << 18; // large enough for the contest. | ||||||
|  |     private static final int TABLE_MASK = (TABLE_SIZE - 1); | ||||||
|  |  | ||||||
|         MeasurementRepository repository = new MeasurementRepository(); |     static final class Entry { | ||||||
|         long ptr = fromAddress; |         private final long[] data; | ||||||
|         long[] dataBuffer = new long[16]; |         private final String city; | ||||||
|         while ((ptr = processEntity(dataBuffer, ptr, toAddress, repository)) < toAddress) |         private int min, max, count; | ||||||
|             ; |         private long sum; | ||||||
|  |  | ||||||
|         return repository; |         Entry(final long[] data, String city, int temp) { | ||||||
|  |             this.data = data; | ||||||
|  |             this.city = city; | ||||||
|  |             this.min = temp; | ||||||
|  |             this.max = temp; | ||||||
|  |             this.sum = temp; | ||||||
|  |             this.count = 1; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|     private static final long SEPARATOR_PATTERN = compilePattern((byte) ';'); |         public void updateWith(int measurement) { | ||||||
|  |             min = Math.min(min, measurement); | ||||||
|     /** |             max = Math.max(max, measurement); | ||||||
|      * Already looping the longs here, lets shoehorn in making a hash |  | ||||||
|      */ |  | ||||||
|     private long processEntity(final long[] data, final long start, final long limit, final MeasurementRepository measurementRepository) { |  | ||||||
|         int hash = 1; |  | ||||||
|         long i; |  | ||||||
|         int dataPtr = 0; |  | ||||||
|         for (i = start; i <= limit - 8; i += 8) { |  | ||||||
|             long word = UNSAFE.getLong(i); |  | ||||||
|             if (isBigEndian) { |  | ||||||
|                 word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this |  | ||||||
|             } |  | ||||||
|             final long match = word ^ SEPARATOR_PATTERN; |  | ||||||
|             long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L; |  | ||||||
|  |  | ||||||
|             if (mask != 0) { |  | ||||||
|  |  | ||||||
|                 final long partialWord = word & ((mask >> 7) - 1); |  | ||||||
|                 hash = longHashStep(hash, partialWord); |  | ||||||
|                 data[dataPtr] = partialWord; |  | ||||||
|  |  | ||||||
|                 final int index = Long.numberOfTrailingZeros(mask) >> 3; |  | ||||||
|                 return process(start, i + index, hash, data, measurementRepository); |  | ||||||
|             } |  | ||||||
|             data[dataPtr++] = word; |  | ||||||
|             hash = longHashStep(hash, word); |  | ||||||
|         } |  | ||||||
|         // Handle remaining bytes near the limit of the buffer: |  | ||||||
|         long partialWord = 0; |  | ||||||
|         int len = 0; |  | ||||||
|         for (; i < limit; i++) { |  | ||||||
|             byte read; |  | ||||||
|             if ((read = UNSAFE.getByte(i)) == ';') { |  | ||||||
|                 hash = longHashStep(hash, partialWord); |  | ||||||
|                 data[dataPtr] = partialWord; |  | ||||||
|                 return process(start, i, hash, data, measurementRepository); |  | ||||||
|             } |  | ||||||
|             partialWord = partialWord | ((long) read << (len << 3)); |  | ||||||
|             len++; |  | ||||||
|         } |  | ||||||
|         return limit; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     private static final long DOT_BITS = 0x10101000; |  | ||||||
|     private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); |  | ||||||
|  |  | ||||||
|     private long process(final long startAddress, final long delimiterAddress, final int hash, final long[] data, final MeasurementRepository measurementRepository) { |  | ||||||
|  |  | ||||||
|         long word = UNSAFE.getLong(delimiterAddress + 1); |  | ||||||
|         if (isBigEndian) { |  | ||||||
|             word = Long.reverseBytes(word); |  | ||||||
|         } |  | ||||||
|         final long invWord = ~word; |  | ||||||
|         final int decimalSepPos = Long.numberOfTrailingZeros(invWord & DOT_BITS); |  | ||||||
|         final long signed = (invWord << 59) >> 63; |  | ||||||
|         final long designMask = ~(signed & 0xFF); |  | ||||||
|         final long digits = ((word & designMask) << (28 - decimalSepPos)) & 0x0F000F0F00L; |  | ||||||
|         final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; |  | ||||||
|         final int measurement = (int) ((absValue ^ signed) - signed); |  | ||||||
|  |  | ||||||
|         // Store: |  | ||||||
|         measurementRepository.update(startAddress, data, (int) (delimiterAddress - startAddress), hash, measurement); |  | ||||||
|  |  | ||||||
|         return delimiterAddress + (decimalSepPos >> 3) + 4; // Determine next start: |  | ||||||
|         // return nextAddress; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     static final class Measurement { |  | ||||||
|         int min, max, count; |  | ||||||
|         long sum; |  | ||||||
|  |  | ||||||
|         public Measurement() { |  | ||||||
|             this.min = 1000; |  | ||||||
|             this.max = -1000; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         public Measurement updateWith(int measurement) { |  | ||||||
|             min = min(min, measurement); |  | ||||||
|             max = max(max, measurement); |  | ||||||
|             sum += measurement; |             sum += measurement; | ||||||
|             count++; |             count++; | ||||||
|             return this; |  | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         public Measurement updateWith(Measurement measurement) { |         public Entry mergeWith(Entry entry) { | ||||||
|             min = min(min, measurement.min); |             min = Math.min(min, entry.min); | ||||||
|             max = max(max, measurement.max); |             max = Math.max(max, entry.max); | ||||||
|             sum += measurement.sum; |             sum += entry.sum; | ||||||
|             count += measurement.count; |             count += entry.count; | ||||||
|             return this; |             return this; | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -221,101 +159,127 @@ public class CalculateAverage_royvanrijn { | |||||||
|             return round(min) + "/" + round((1.0 * sum) / count) + "/" + round(max); |             return round(min) + "/" + round((1.0 * sum) / count) + "/" + round(max); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         private double round(double value) { |         private static double round(double value) { | ||||||
|             return Math.round(value) / 10.0; |             return Math.round(value) / 10.0; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // branchless max (unprecise for large numbers, but good enough) |     private static Entry createNewEntry(final long[] buffer, final long startAddress, final int lengthLongs, final int lengthBytes, final int temp) { | ||||||
|     static int max(final int a, final int b) { |  | ||||||
|         final int diff = a - b; |         // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here. | ||||||
|         final int dsgn = diff >> 31; |         final byte[] bytes = new byte[lengthBytes]; | ||||||
|         return a - (diff & dsgn); |         UNSAFE.copyMemory(null, startAddress, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, lengthBytes); | ||||||
|  |         final String city = new String(bytes, StandardCharsets.UTF_8); | ||||||
|  |  | ||||||
|  |         final long[] bufferCopy = new long[lengthLongs]; | ||||||
|  |         System.arraycopy(buffer, 0, bufferCopy, 0, lengthLongs); | ||||||
|  |  | ||||||
|  |         // Add the entry: | ||||||
|  |         return new Entry(bufferCopy, city, temp); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // branchless min (unprecise for large numbers, but good enough) |     private static Entry[] processMemoryArea(final long fromAddress, final long toAddress) { | ||||||
|     static int min(final int a, final int b) { |  | ||||||
|         final int diff = a - b; |         Entry[] table = new Entry[TABLE_SIZE]; | ||||||
|         final int dsgn = diff >> 31; |  | ||||||
|         return b + (diff & dsgn); |         long ptr = fromAddress; | ||||||
|  |         long[] buffer = new long[14]; | ||||||
|  |  | ||||||
|  |         while (ptr < toAddress) { | ||||||
|  |  | ||||||
|  |             int bufferPtr = 0; | ||||||
|  |             long startAddress = ptr; | ||||||
|  |             long hash = 1; | ||||||
|  |  | ||||||
|  |             long word = UNSAFE.getLong(ptr); | ||||||
|  |             long mask = getDelimiterMask(word); | ||||||
|  |  | ||||||
|  |             while (mask == 0) { | ||||||
|  |                 buffer[bufferPtr++] = word; | ||||||
|  |                 hash ^= word; | ||||||
|  |                 ptr += 8; | ||||||
|  |  | ||||||
|  |                 word = UNSAFE.getLong(ptr); | ||||||
|  |                 mask = getDelimiterMask(word); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|     private static int longHashStep(final int hash, final long word) { |             // Found delimiter: | ||||||
|         return 31 * hash + (int) (word ^ (word >>> 32)); |             final long delimiterAddress = ptr + (Long.numberOfTrailingZeros(mask) >> 3); | ||||||
|  |             final long numberBits = UNSAFE.getLong(delimiterAddress + 1); | ||||||
|  |  | ||||||
|  |             // Finish the masks and hash: | ||||||
|  |             final long partialWord = word & ((mask >> 7) - 1); | ||||||
|  |             buffer[bufferPtr++] = partialWord; | ||||||
|  |             hash ^= partialWord; | ||||||
|  |  | ||||||
|  |             final long invNumberBits = ~numberBits; | ||||||
|  |             final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & DOT_BITS); | ||||||
|  |  | ||||||
|  |             // Update counter asap, lets CPU predict. | ||||||
|  |             ptr = delimiterAddress + (decimalSepPos >> 3) + 4; | ||||||
|  |  | ||||||
|  |             int intHash = (int) (hash ^ (hash >>> 31)); // offset for extra entropy | ||||||
|  |  | ||||||
|  |             // Awesome idea of merykitty: | ||||||
|  |             final int temp = extractTemp(numberBits, invNumberBits, decimalSepPos); | ||||||
|  |  | ||||||
|  |             int index = intHash & TABLE_MASK; | ||||||
|  |  | ||||||
|  |             // Find or insert the entry: | ||||||
|  |             while (true) { | ||||||
|  |                 Entry tableEntry = table[index]; | ||||||
|  |                 if (tableEntry == null) { | ||||||
|  |                     final int length = (int) (delimiterAddress - startAddress); | ||||||
|  |                     table[index] = createNewEntry(buffer, startAddress, bufferPtr, length, temp); | ||||||
|  |                     break; | ||||||
|                 } |                 } | ||||||
|  |                 else if (bufferPtr == tableEntry.data.length) { | ||||||
|  |                     if (!arrayEquals(buffer, tableEntry.data, bufferPtr)) { | ||||||
|  |                         index = (index + 1) & TABLE_MASK; | ||||||
|  |                         continue; | ||||||
|  |                     } | ||||||
|  |                     // No differences in array | ||||||
|  |                     tableEntry.updateWith(temp); | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |                 // Move to the next index | ||||||
|  |                 index = (index + 1) & TABLE_MASK; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         return table; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     private static int extractTemp(final long numberBits, final long invNumberBits, final int decimalSepPos) { | ||||||
|  |         final long signed = (invNumberBits << 59) >> 63; | ||||||
|  |         final long minusFilter = ~(signed & 0xFF); | ||||||
|  |         final long digits = ((numberBits & minusFilter) << (28 - decimalSepPos)) & 0x0F000F0F00L; | ||||||
|  |         final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; // filter just the result | ||||||
|  |         final int temp = (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick | ||||||
|  |         return temp; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     private static long getDelimiterMask(final long word) { | ||||||
|  |         long match = word ^ SEPARATOR_PATTERN; | ||||||
|  |         return (match - 0x0101010101010101L) & ~match & 0x8080808080808080L; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     private static final long SEPARATOR_PATTERN = compilePattern((byte) ';'); | ||||||
|  |     private static final long DOT_BITS = 0x10101000; | ||||||
|  |     private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); | ||||||
|  |  | ||||||
|     private static long compilePattern(final byte value) { |     private static long compilePattern(final byte value) { | ||||||
|         return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) | |         return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) | | ||||||
|                 ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value; |                 ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /** |  | ||||||
|      * A normal Java HashMap does all these safety things like boundary checks... we don't need that, we need speeeed. |  | ||||||
|      * |  | ||||||
|      * So I've written an extremely simple linear probing hashmap that should work well enough. |  | ||||||
|      */ |  | ||||||
|     class MeasurementRepository { |  | ||||||
|         private int tableSize = 1 << 20; // large enough for the contest. |  | ||||||
|         private int tableMask = (tableSize - 1); |  | ||||||
|  |  | ||||||
|         private MeasurementRepository.Entry[] table = new MeasurementRepository.Entry[tableSize]; |  | ||||||
|  |  | ||||||
|         record Entry(long address, long[] data, int length, int hash, String city, Measurement measurement) { |  | ||||||
|  |  | ||||||
|             @Override |  | ||||||
|             public String toString() { |  | ||||||
|                 return city + "=" + measurement; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         public void update(long address, long[] data, int length, int hash, int temperature) { |  | ||||||
|  |  | ||||||
|             int dataLength = length >> 3; |  | ||||||
|             int index = hash & tableMask; |  | ||||||
|             MeasurementRepository.Entry tableEntry; |  | ||||||
|             while ((tableEntry = table[index]) != null |  | ||||||
|                     && (tableEntry.hash != hash || tableEntry.length != length || !arrayEquals(tableEntry.data, data, dataLength))) { // search for the right spot |  | ||||||
|                 index = (index + 1) & tableMask; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             if (tableEntry != null) { |  | ||||||
|                 tableEntry.measurement.updateWith(temperature); |  | ||||||
|                 return; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here. |  | ||||||
|             Measurement measurement = new Measurement(); |  | ||||||
|  |  | ||||||
|             byte[] bytes = new byte[length]; |  | ||||||
|             for (int i = 0; i < length; i++) { |  | ||||||
|                 bytes[i] = UNSAFE.getByte(address + i); |  | ||||||
|             } |  | ||||||
|             String city = new String(bytes); |  | ||||||
|  |  | ||||||
|             long[] dataCopy = new long[dataLength]; |  | ||||||
|             System.arraycopy(data, 0, dataCopy, 0, dataLength); |  | ||||||
|  |  | ||||||
|             // And add entry: |  | ||||||
|             MeasurementRepository.Entry toAdd = new MeasurementRepository.Entry(address, dataCopy, length, hash, city, measurement); |  | ||||||
|             table[index] = toAdd; |  | ||||||
|  |  | ||||||
|             toAdd.measurement.updateWith(temperature); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         public Stream<MeasurementRepository.Entry> get() { |  | ||||||
|             return Arrays.stream(table).filter(Objects::nonNull); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     /** |     /** | ||||||
|      * For case multiple hashes are equal (however unlikely) check the actual key (using longs) |      * For case multiple hashes are equal (however unlikely) check the actual key (using longs) | ||||||
|      */ |      */ | ||||||
|     private boolean arrayEquals(final long[] a, final long[] b, final int length) { |     static boolean arrayEquals(final long[] a, final long[] b, final int length) { | ||||||
|         for (int i = 0; i < length; i++) { |         for (int i = 0; i < length; i++) { | ||||||
|             if (a[i] != b[i]) |             if (a[i] != b[i]) | ||||||
|                 return false; |                 return false; | ||||||
|         } |         } | ||||||
|         return true; |         return true; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user