Locally another 5% faster, much faster for larger set, made more general (#352)
This commit is contained in:
parent
bd4cff945d
commit
9227aa5062
@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets;
|
|||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.StandardOpenOption;
|
import java.nio.file.StandardOpenOption;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.IntStream;
|
import java.util.stream.IntStream;
|
||||||
@ -53,8 +52,13 @@ import sun.misc.Unsafe;
|
|||||||
* Various tweaks for Linux/cache 1550 ms (should/could make a difference on target machine)
|
* Various tweaks for Linux/cache 1550 ms (should/could make a difference on target machine)
|
||||||
* Improved layout/predictability: 1400 ms
|
* Improved layout/predictability: 1400 ms
|
||||||
* Delayed String creation again: 1350 ms
|
* Delayed String creation again: 1350 ms
|
||||||
|
* Remove writing to buffer: 1335 ms
|
||||||
|
* Optimized collecting at the end: 1310 ms
|
||||||
|
* Adding a lot of comments: priceless
|
||||||
*
|
*
|
||||||
* Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai for ideas.
|
* Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai for ideas.
|
||||||
|
*
|
||||||
|
* Follow me at: @royvanrijn
|
||||||
*/
|
*/
|
||||||
public class CalculateAverage_royvanrijn {
|
public class CalculateAverage_royvanrijn {
|
||||||
|
|
||||||
@ -74,29 +78,24 @@ public class CalculateAverage_royvanrijn {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
|
|
||||||
// Calculate input segments.
|
// Calculate input segments.
|
||||||
final int numberOfChunks = Runtime.getRuntime().availableProcessors();
|
final int numberOfChunks = Runtime.getRuntime().availableProcessors();
|
||||||
final long[] chunks = getSegments(numberOfChunks);
|
final long[] chunks = getSegments(numberOfChunks);
|
||||||
|
|
||||||
final List<Entry[]> repositories = IntStream.range(0, chunks.length - 1)
|
final Map<String, Entry> measurements = HashMap.newHashMap(1 << 10);
|
||||||
|
IntStream.range(0, chunks.length - 1)
|
||||||
.mapToObj(chunkIndex -> processMemoryArea(chunks[chunkIndex], chunks[chunkIndex + 1]))
|
.mapToObj(chunkIndex -> processMemoryArea(chunks[chunkIndex], chunks[chunkIndex + 1]))
|
||||||
.parallel()
|
.parallel()
|
||||||
.toList();
|
.forEachOrdered(repo -> { // make sure it's ordered, no concurrent map
|
||||||
|
for (Entry entry : repo) {
|
||||||
// Sometimes simple is better:
|
|
||||||
final HashMap<String, Entry> measurements = HashMap.newHashMap(1 << 10);
|
|
||||||
for (Entry[] entries : repositories) {
|
|
||||||
for (Entry entry : entries) {
|
|
||||||
if (entry != null)
|
if (entry != null)
|
||||||
measurements.merge(extractedCityFromLongArray(entry.data, entry.length), entry, Entry::mergeWith);
|
measurements.merge(turnLongArrayIntoString(entry.data, entry.length), entry, Entry::mergeWith);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
System.out.print("{" +
|
System.out.print("{" +
|
||||||
measurements.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Object::toString).collect(Collectors.joining(", ")));
|
measurements.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Object::toString).collect(Collectors.joining(", ")));
|
||||||
System.out.println("}");
|
System.out.println("}");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -123,15 +122,20 @@ public class CalculateAverage_royvanrijn {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final int TABLE_SIZE = 1 << 19; // large enough for the contest.
|
// This is where I store the hashtable entry data in the "hot loop"
|
||||||
private static final int TABLE_MASK = (TABLE_SIZE - 1);
|
// The long[] contains the name in bytes (yeah, confusing)
|
||||||
|
// I've tried flyweight-ing, carrying all the data in a single byte[],
|
||||||
|
// where you offset type-indices: min:int,max:int,count:int,etc.
|
||||||
|
//
|
||||||
|
// The performance was just a little worse than this simple class.
|
||||||
static final class Entry {
|
static final class Entry {
|
||||||
private final long[] data;
|
|
||||||
private int min, max, count, length;
|
|
||||||
private long sum;
|
|
||||||
|
|
||||||
Entry(final long[] data, int length, int temp) {
|
private int min, max, count;
|
||||||
|
private byte length;
|
||||||
|
private long sum;
|
||||||
|
private final long[] data;
|
||||||
|
|
||||||
|
Entry(final long[] data, byte length, int temp) {
|
||||||
this.data = data;
|
this.data = data;
|
||||||
this.length = length;
|
this.length = length;
|
||||||
this.min = temp;
|
this.min = temp;
|
||||||
@ -164,127 +168,161 @@ public class CalculateAverage_royvanrijn {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
// Only parse the String at the final end, when we have only the needed entries left that we need to output:
|
||||||
* Delay String creation until the end:
|
private static String turnLongArrayIntoString(final long[] data, final int length) {
|
||||||
* @param data
|
// Create our target byte[]
|
||||||
* @param length
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
private static String extractedCityFromLongArray(final long[] data, final int length) {
|
|
||||||
// Initiate as late as possible:
|
|
||||||
final byte[] bytes = new byte[length];
|
final byte[] bytes = new byte[length];
|
||||||
|
// The power of magic allows us to just copy the memory in there.
|
||||||
UNSAFE.copyMemory(data, Unsafe.ARRAY_LONG_BASE_OFFSET, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
|
UNSAFE.copyMemory(data, Unsafe.ARRAY_LONG_BASE_OFFSET, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
|
||||||
|
// And construct a String()
|
||||||
return new String(bytes, StandardCharsets.UTF_8);
|
return new String(bytes, StandardCharsets.UTF_8);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Entry createNewEntry(final long[] buffer, final int lengthLongs, final int lengthBytes, final int temp) {
|
private static Entry createNewEntry(final long fromAddress, final int lengthLongs, final byte lengthBytes, final int temp) {
|
||||||
|
// Make a copy of our working buffer, store this in a new Entry:
|
||||||
final long[] bufferCopy = new long[lengthLongs];
|
final long[] bufferCopy = new long[lengthLongs];
|
||||||
System.arraycopy(buffer, 0, bufferCopy, 0, lengthLongs);
|
// Just copy everything over, bytes into the long[]
|
||||||
|
UNSAFE.copyMemory(null, fromAddress, bufferCopy, Unsafe.ARRAY_BYTE_BASE_OFFSET, lengthBytes);
|
||||||
// Add the entry:
|
|
||||||
return new Entry(bufferCopy, lengthBytes, temp);
|
return new Entry(bufferCopy, lengthBytes, temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final int TABLE_SIZE = 1 << 19;
|
||||||
|
private static final int TABLE_MASK = (TABLE_SIZE - 1);
|
||||||
|
|
||||||
private static Entry[] processMemoryArea(final long fromAddress, final long toAddress) {
|
private static Entry[] processMemoryArea(final long fromAddress, final long toAddress) {
|
||||||
|
|
||||||
final Entry[] table = new Entry[TABLE_SIZE];
|
int packedBytes;
|
||||||
final long[] buffer = new long[16];
|
|
||||||
|
|
||||||
long ptr = fromAddress;
|
|
||||||
int bufferPtr;
|
|
||||||
long hash;
|
long hash;
|
||||||
|
long ptr = fromAddress;
|
||||||
long word;
|
long word;
|
||||||
long mask;
|
long mask;
|
||||||
|
|
||||||
|
final Entry[] table = new Entry[TABLE_SIZE];
|
||||||
|
|
||||||
|
// Go from start to finish address through the bytes:
|
||||||
while (ptr < toAddress) {
|
while (ptr < toAddress) {
|
||||||
|
|
||||||
final long startAddress = ptr;
|
final long startAddress = ptr;
|
||||||
|
|
||||||
bufferPtr = 0;
|
packedBytes = 1;
|
||||||
hash = 1;
|
hash = 0;
|
||||||
word = UNSAFE.getLong(ptr);
|
word = UNSAFE.getLong(ptr);
|
||||||
mask = getDelimiterMask(word);
|
mask = getDelimiterMask(word);
|
||||||
|
|
||||||
|
// Removed writing to a buffer here, why would we, we know the address and we'll need to check there anyway.
|
||||||
while (mask == 0) {
|
while (mask == 0) {
|
||||||
buffer[bufferPtr++] = word;
|
// If the mask is zero, we have no ';'
|
||||||
|
packedBytes++;
|
||||||
|
// So we continue building the hash:
|
||||||
hash ^= word;
|
hash ^= word;
|
||||||
ptr += 8;
|
ptr += 8;
|
||||||
|
|
||||||
|
// And getting a new value and mask:
|
||||||
word = UNSAFE.getLong(ptr);
|
word = UNSAFE.getLong(ptr);
|
||||||
mask = getDelimiterMask(word);
|
mask = getDelimiterMask(word);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Found delimiter:
|
// Found delimiter:
|
||||||
final long delimiterAddress = ptr + (Long.numberOfTrailingZeros(mask) >> 3);
|
final int delimiterByte = Long.numberOfTrailingZeros(mask);
|
||||||
final long numberBits = UNSAFE.getLong(delimiterAddress + 1);
|
final long delimiterAddress = ptr + (delimiterByte >> 3);
|
||||||
|
|
||||||
// Finish the masks and hash:
|
// Finish the masks and hash:
|
||||||
word = word & ((mask >> 7) - 1);
|
final long partialWord = word & ((mask >>> 7) - 1);
|
||||||
buffer[bufferPtr++] = word;
|
hash ^= partialWord;
|
||||||
hash ^= word;
|
|
||||||
|
|
||||||
final long invNumberBits = ~numberBits;
|
// Read a long value from memory starting from the delimiter + 1, the number part:
|
||||||
final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & DOT_BITS);
|
final long numberBytes = UNSAFE.getLong(delimiterAddress + 1);
|
||||||
|
final long invNumberBytes = ~numberBytes;
|
||||||
|
|
||||||
// Update counter asap, lets CPU predict.
|
// Adjust our pointer
|
||||||
|
final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBytes & DOT_BITS);
|
||||||
ptr = delimiterAddress + (decimalSepPos >> 3) + 4;
|
ptr = delimiterAddress + (decimalSepPos >> 3) + 4;
|
||||||
|
|
||||||
// Awesome idea of merykitty:
|
// Calculate the final hash and index of the table:
|
||||||
final int temp = extractTemp(numberBits, invNumberBits, decimalSepPos);
|
int intHash = (int) (hash ^ (hash >> 32));
|
||||||
|
intHash = intHash ^ (intHash >> 17);
|
||||||
int intHash = (int) (hash ^ (hash >>> 33)); // offset for extra entropy
|
|
||||||
int index = intHash & TABLE_MASK;
|
int index = intHash & TABLE_MASK;
|
||||||
|
|
||||||
// Find or insert the entry:
|
// Find or insert the entry:
|
||||||
while (true) {
|
while (true) {
|
||||||
Entry tableEntry = table[index];
|
Entry tableEntry = table[index];
|
||||||
if (tableEntry == null) {
|
if (tableEntry == null) {
|
||||||
final int length = (int) (delimiterAddress - startAddress);
|
final int temp = extractTemp(decimalSepPos, invNumberBytes, numberBytes);
|
||||||
table[index] = createNewEntry(buffer, bufferPtr, length, temp);
|
// Create a new entry:
|
||||||
|
final byte length = (byte) (delimiterAddress - startAddress);
|
||||||
|
table[index] = createNewEntry(startAddress, packedBytes, length, temp);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
else if (bufferPtr == tableEntry.data.length) {
|
// Don't bother re-checking things here like hash or length.
|
||||||
if (!arrayEquals(buffer, tableEntry.data, bufferPtr)) {
|
// we'll need to check the content anyway if it's a hit, which is most times
|
||||||
index = (index + 1) & TABLE_MASK;
|
else if (memoryEqualsEntry(startAddress, tableEntry.data, partialWord, packedBytes)) {
|
||||||
continue;
|
// temperature, you're not temporary my friend
|
||||||
}
|
final int temp = extractTemp(decimalSepPos, invNumberBytes, numberBytes);
|
||||||
// No differences in array
|
// No differences, same entry:
|
||||||
tableEntry.updateWith(temp);
|
tableEntry.updateWith(temp);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// Move to the next index
|
// Move to the next in the table, linear probing:
|
||||||
index = (index + 1) & TABLE_MASK;
|
index = (index + 1) & TABLE_MASK;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return table;
|
return table;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static int extractTemp(final long numberBits, final long invNumberBits, final int decimalSepPos) {
|
/*
|
||||||
|
* `___` ___ ___ _ ___` ` ___ ` _ ` _ ` _` ___
|
||||||
|
* / ` \| _ \ __| \| \ \ / /_\ | | | | | | __|
|
||||||
|
* | () | _ / __|| . |\ V / _ \| |_| |_| | ._|
|
||||||
|
* \___/|_| |___|_|\_| \_/_/ \_\___|\___/|___|
|
||||||
|
* ---------------- BETTER SOFTWARE, FASTER --
|
||||||
|
*
|
||||||
|
* https://www.openvalue.eu/
|
||||||
|
*
|
||||||
|
* Made you look.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
private static final long DOT_BITS = 0x10101000;
|
||||||
|
private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
|
||||||
|
|
||||||
|
private static int extractTemp(final int decimalSepPos, final long invNumberBits, final long numberBits) {
|
||||||
|
// Awesome idea of merykitty:
|
||||||
|
int min28 = (28 - decimalSepPos);
|
||||||
|
// Calculates the sign
|
||||||
final long signed = (invNumberBits << 59) >> 63;
|
final long signed = (invNumberBits << 59) >> 63;
|
||||||
final long minusFilter = ~(signed & 0xFF);
|
final long minusFilter = ~(signed & 0xFF);
|
||||||
final long digits = ((numberBits & minusFilter) << (28 - decimalSepPos)) & 0x0F000F0F00L;
|
// Use the pre-calculated decimal position to adjust the values
|
||||||
final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; // filter just the result
|
final long digits = ((numberBits & minusFilter) << min28) & 0x0F000F0F00L;
|
||||||
|
// Multiply by a magic (100 * 0x1000000 + 10 * 0x10000 + 1), to get the result
|
||||||
|
final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
|
||||||
|
// And perform abs()
|
||||||
final int temp = (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick
|
final int temp = (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick
|
||||||
return temp;
|
return temp;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static long getDelimiterMask(final long word) {
|
|
||||||
long match = word ^ SEPARATOR_PATTERN;
|
|
||||||
return (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL;
|
private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL;
|
||||||
private static final long DOT_BITS = 0x10101000;
|
|
||||||
private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
|
// Takes a long and finds the bytes where this exact pattern is present.
|
||||||
|
// Cool bit manipulation technique: SWAR (SIMD as a Register).
|
||||||
|
private static long getDelimiterMask(final long word) {
|
||||||
|
final long match = word ^ SEPARATOR_PATTERN;
|
||||||
|
return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
|
||||||
|
// I've put some brackets separating the first and second part, this is faster.
|
||||||
|
// Now they run simultaneous after 'match' is altered, instead of waiting on each other.
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* For case multiple hashes are equal (however unlikely) check the actual key (using longs)
|
* For case multiple hashes are equal (however unlikely) check the actual key (using longs)
|
||||||
*/
|
*/
|
||||||
static boolean arrayEquals(final long[] a, final long[] b, final int length) {
|
private static boolean memoryEqualsEntry(final long startAddress, final long[] entry, final long finalBytes, final int amountLong) {
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < (amountLong - 1); i++) {
|
||||||
if (a[i] != b[i])
|
int step = i << 3; // step by 8 bytes
|
||||||
|
if (UNSAFE.getLong(startAddress + step) != entry[i])
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
// If all previous 'whole' 8-packed byte-long values are equal
|
||||||
|
// We still need to check the final bytes that don't fit.
|
||||||
|
// and we've already calculated them for the hash.
|
||||||
|
return finalBytes == entry[amountLong - 1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user