Updating Roy's submission

* Added tests for endian-calculations (had these in a different class, perhaps handy for others to see as well)

Inlined the hash function, runs locally in 2.4sec now, hopefully endian issues fix

Added equals to support any city name up to 1024 in length, don't rely on hash

* For clarity I've updated the code so endian doesn't change the hashes, easier to debug.

* Fixing bug in array check

Simple is faster

* Also spotted the diff, not just the big exception

Fixed buffer limit issue
This commit is contained in:
Roy van Rijn 2024-01-04 23:22:48 +01:00 committed by GitHub
parent acb6510a02
commit 1c74049991
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 226 additions and 79 deletions

View File

@ -15,7 +15,7 @@
# limitations under the License.
#
sdk use java 21.0.1-graal
# Added for fun, doesn't seem to be making a difference...
if [ -f "target/calculate_average_royvanrijn.jsa" ]; then
JAVA_OPTS="-XX:SharedArchiveFile=target/calculate_average_royvanrijn.jsa -Xshare:on"

View File

@ -21,10 +21,12 @@ import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
import java.util.stream.Collectors;
@ -44,6 +46,7 @@ import java.util.stream.Collectors;
* Improved String skip: 3250 ms
* Segmenting files: 3150 ms (based on spullara's code)
* Not using SWAR for EOL: 2850 ms
* Inlining hash calculation: 2450 ms
*
* Best performing JVM on MacBook M2 Pro: 21.0.1-graal
* `sdk use java 21.0.1-graal`
@ -59,8 +62,8 @@ public class CalculateAverage_royvanrijn {
long sum;
public Measurement() {
this.min = 10000;
this.max = -10000;
this.min = 1000;
this.max = -1000;
}
public Measurement updateWith(int measurement) {
@ -88,8 +91,32 @@ public class CalculateAverage_royvanrijn {
}
}
public static final void main(String[] args) throws Exception {
public static void main(String[] args) throws Exception {
new CalculateAverage_royvanrijn().run();
// new CalculateAverage_royvanrijn().runTests();
}
private void testInput(final String inputString, final int start, final boolean bigEndian, final int[] expectedDelimiterAndHash, final long[] expectedCityNameLong) {
byte[] input = inputString.getBytes(StandardCharsets.UTF_8);
ByteBuffer buffer = ByteBuffer.wrap(input).order(bigEndian ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN);
int[] output = new int[2];
long[] cityName = new long[128];
findNextDelimiterAndCalculateHash(buffer, SEPARATOR_PATTERN, start, buffer.limit(), output, cityName, bigEndian);
if (!Arrays.equals(output, expectedDelimiterAndHash)) {
System.out.println("Error in delimiter or hash");
System.out.println("Expected: " + Arrays.toString(expectedDelimiterAndHash));
System.out.println("Received: " + Arrays.toString(output));
}
int amountLong = 1 + ((output[0] - start) >>> 3);
if (!Arrays.equals(cityName, 0, amountLong, expectedCityNameLong, 0, amountLong)) {
System.out.println("Error in long array");
System.out.println("Expected: " + Arrays.toString(expectedCityNameLong));
System.out.println("Received: " + Arrays.toString(cityName));
}
}
private void run() throws Exception {
@ -99,30 +126,43 @@ public class CalculateAverage_royvanrijn {
long segmentEnd = segment.end();
try (var fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ)) {
var bb = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segmentEnd - segment.start());
var buffer = new byte[64];
// Force little endian:
bb.order(ByteOrder.LITTLE_ENDIAN);
// Work with any UTF-8 city name, up to 100 in length:
var buffer = new byte[106]; // 100 + ; + -XX.X + \n
var cityNameAsLongArray = new long[13]; // 13*8=104=kenough.
var delimiterPointerAndHash = new int[2];
BitTwiddledMap measurements = new BitTwiddledMap();
// Calculate using native ordering (fastest?):
bb.order(ByteOrder.nativeOrder());
// Record the order it is and calculate accordingly:
final boolean bufferIsBigEndian = bb.order().equals(ByteOrder.BIG_ENDIAN);
MeasurementRepository measurements = new MeasurementRepository();
int startPointer;
int limit = bb.limit();
while ((startPointer = bb.position()) < limit) {
// SWAR is faster for ';'
int separatorPointer = findNextSWAR(bb, SEPARATOR_PATTERN, startPointer + 3, limit);
// SWAR method to find delimiter *and* record the cityname as long[] *and* calculate a hash:
findNextDelimiterAndCalculateHash(bb, SEPARATOR_PATTERN, startPointer, limit, delimiterPointerAndHash, cityNameAsLongArray, bufferIsBigEndian);
int delimiterPointer = delimiterPointerAndHash[0];
// Simple is faster for '\n' (just three options)
// Simple lookup is faster for '\n' (just three options)
int endPointer;
if (bb.get(separatorPointer + 4) == '\n') {
endPointer = separatorPointer + 4;
if (delimiterPointer >= limit) {
bb.position(limit); // skip to next line.
return measurements;
}
else if (bb.get(separatorPointer + 5) == '\n') {
endPointer = separatorPointer + 5;
if (bb.get(delimiterPointer + 4) == '\n') {
endPointer = delimiterPointer + 4;
}
else if (bb.get(delimiterPointer + 5) == '\n') {
endPointer = delimiterPointer + 5;
}
else {
endPointer = separatorPointer + 6;
endPointer = delimiterPointer + 6;
}
// Read the entry in a single get():
@ -130,20 +170,22 @@ public class CalculateAverage_royvanrijn {
bb.position(endPointer + 1); // skip to next line.
// Extract the measurement value (10x):
final int nameLength = separatorPointer - startPointer;
final int valueLength = endPointer - separatorPointer - 1;
final int measured = branchlessParseInt(buffer, nameLength + 1, valueLength);
measurements.getOrCreate(buffer, nameLength).updateWith(measured);
final int cityNameLength = delimiterPointer - startPointer;
final int measuredValueLength = endPointer - delimiterPointer - 1;
final int measuredValue = branchlessParseInt(buffer, cityNameLength + 1, measuredValueLength);
// Store everything in a custom hashtable:
measurements.update(buffer, cityNameLength, delimiterPointerAndHash[1], cityNameAsLongArray).updateWith(measuredValue);
}
return measurements;
}
catch (IOException e) {
throw new RuntimeException(e);
}
}).parallel().flatMap(v -> v.values.stream())
.collect(Collectors.toMap(e -> new String(e.key), BitTwiddledMap.Entry::measurement, (m1, m2) -> m1.updateWith(m2), TreeMap::new));
}).parallel()
.flatMap(v -> v.values.stream())
.collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, Measurement::updateWith, TreeMap::new));
// Seems to perform better than actually using a TreeMap:
System.out.println(results);
}
@ -151,47 +193,119 @@ public class CalculateAverage_royvanrijn {
* -------- This section contains SWAR code (SIMD Within A Register) which processes a bytebuffer as longs to find values:
*/
private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');
private static final long[] PARTIAL_INDEX_MASKS = new long[]{ 0L, 255L, 65535L, 16777215L, 4294967295L, 1099511627775L, 281474976710655L, 72057594037927935L };
private int findNextSWAR(ByteBuffer bb, long pattern, int start, int limit) {
int i;
for (i = start; i <= limit - 8; i += 8) {
long word = bb.getLong(i);
int index = firstAnyPattern(word, pattern);
if (index < Long.BYTES) {
return i + index;
}
public void runTests() {
// Method used for debugging purposes, easy to make mistakes with all the bit hacking.
// These all have the same hashes:
testInput("Delft;-12.4", 0, true, new int[]{ 5, 1718384401 }, new long[]{ 499934586180L });
testInput("aDelft;-12.4", 1, true, new int[]{ 6, 1718384401 }, new long[]{ 499934586180L });
testInput("Delft;-12.4", 0, false, new int[]{ 5, 1718384401 }, new long[]{ 499934586180L });
testInput("aDelft;-12.4", 1, false, new int[]{ 6, 1718384401 }, new long[]{ 499934586180L });
testInput("Rotterdam;-12.4", 0, true, new int[]{ 9, -784321989 }, new long[]{ 7017859899421126482L, 109L });
testInput("abcdefghijklmnpoqrstuvwxyzRotterdam;-12.4", 26, true, new int[]{ 35, -784321989 }, new long[]{ 7017859899421126482L, 109L });
testInput("abcdefghijklmnpoqrstuvwxyzARotterdam;-12.4", 27, true, new int[]{ 36, -784321989 }, new long[]{ 7017859899421126482L, 109L });
testInput("Rotterdam;-12.4", 0, false, new int[]{ 9, -784321989 }, new long[]{ 7017859899421126482L, 109L });
testInput("abcdefghijklmnpoqrstuvwxyzRotterdam;-12.4", 26, false, new int[]{ 35, -784321989 }, new long[]{ 7017859899421126482L, 109L });
testInput("abcdefghijklmnpoqrstuvwxyzARotterdam;-12.4", 27, false, new int[]{ 36, -784321989 }, new long[]{ 7017859899421126482L, 109L });
// These have different hashes from the strings above:
testInput("abcdefghijklmnpoqrstuvwxyzAROtterdam;-12.4", 27, true, new int[]{ 36, -792194501 }, new long[]{ 7017859899421118290L, 109L });
testInput("abcdefghijklmnpoqrstuvwxyzAROtterdam;-12.4", 27, false, new int[]{ 36, -792194501 }, new long[]{ 7017859899421118290L, 109L });
MeasurementRepository repository = new MeasurementRepository();
// Simulate adding two entries with the same hash:
byte[] b1 = "City1;10.0".getBytes();
byte[] b2 = "City2;41.1".getBytes();
repository.update(b1, 5, 1234, new long[]{ 1234L });
repository.update(b2, 5, 1234, new long[]{ 4321L });
// And update the same record shouldn't add a third (this happened):
repository.update(b1, 5, 1234, new long[]{ 1234L });
if (repository.values.size() != 2) {
System.out.println("Error, should have two entries:");
System.out.println(repository.values);
}
// Handle remaining bytes
for (; i < limit; i++) {
if (bb.get(i) == (byte) pattern) {
return i;
}
MeasurementRepository.Entry firstInserted = repository.values.getFirst();
if (!firstInserted.cityName.equals("City1")) {
System.out.println("Error, should have correct name: " + firstInserted.cityName);
}
return limit; // delimiter not found
}
private static long compilePattern(byte value) {
/**
* Already looping the longs here, lets shoehorn in making a hash
*/
private void findNextDelimiterAndCalculateHash(final ByteBuffer bb, final long pattern, final int start, final int limit, final int[] output,
final long[] asLong, final boolean bufferBigEndian) {
int hash = 1;
int i;
int lCnt = 0;
for (i = start; i <= limit - 8; i += 8) {
long word = bb.getLong(i);
if (bufferBigEndian)
word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this
int index = firstAnyPattern(word, pattern);
if (index < Long.BYTES) {
final long partialHash = word & PARTIAL_INDEX_MASKS[index];
asLong[lCnt] = partialHash;
hash = 961 * hash + 31 * (int) (partialHash >>> 32) + (int) partialHash;
output[0] = (i + index);
output[1] = hash;
return;
}
asLong[lCnt++] = word;
hash = 961 * hash + 31 * (int) (word >>> 32) + (int) word;
}
// Handle remaining bytes
long partialHash = 0;
for (; i < limit; i++) {
byte read;
if ((read = bb.get(i)) == (byte) pattern) {
asLong[lCnt] = partialHash;
hash = 961 * hash + 31 * (int) (partialHash >>> 32) + (int) partialHash;
output[0] = i;
output[1] = hash;
return;
}
partialHash = partialHash << 8 | read;
}
output[0] = limit; // delimiter not found
output[1] = hash;
}
private static long compilePattern(final byte value) {
return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) |
((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value;
}
private static int firstAnyPattern(long word, long pattern) {
private static int firstAnyPattern(final long word, final long pattern) {
final long match = word ^ pattern;
long mask = match - 0x0101010101010101L;
mask &= ~match;
mask &= 0x8080808080808080L;
return Long.numberOfTrailingZeros(mask) >>> 3;
return Long.numberOfTrailingZeros(mask) >> 3;
}
record FileSegment(long start, long end) {
}
/** Using this way to segment the file is much prettier, from spullara */
private static List<FileSegment> getFileSegments(File file) throws IOException {
private static List<FileSegment> getFileSegments(final File file) throws IOException {
final int numberOfSegments = Runtime.getRuntime().availableProcessors();
final long fileSize = file.length();
final long segmentSize = fileSize / numberOfSegments;
final List<FileSegment> segments = new ArrayList<>();
if (segmentSize < 1000) {
segments.add(new FileSegment(0, fileSize));
return segments;
}
try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) {
for (int i = 0; i < numberOfSegments; i++) {
long segStart = i * segmentSize;
@ -205,7 +319,7 @@ public class CalculateAverage_royvanrijn {
return segments;
}
private static long findSegment(int i, int skipSegment, RandomAccessFile raf, long location, long fileSize) throws IOException {
private static long findSegment(final int i, final int skipSegment, RandomAccessFile raf, long location, final long fileSize) throws IOException {
if (i != skipSegment) {
raf.seek(location);
while (location < fileSize) {
@ -226,7 +340,7 @@ public class CalculateAverage_royvanrijn {
* @param input
* @return int value x10
*/
private static int branchlessParseInt(final byte[] input, int start, int length) {
private static int branchlessParseInt(final byte[] input, final int start, final int length) {
// 0 if positive, 1 if negative
final int negative = ~(input[start] >> 4) & 1;
// 0 if nr length is 3, 1 if length is 4
@ -258,66 +372,99 @@ public class CalculateAverage_royvanrijn {
*
* So I've written an extremely simple linear probing hashmap that should work well enough.
*/
class BitTwiddledMap {
private static final int SIZE = 16384; // A bit larger than the number of keys, needs power of two
private int[] indices = new int[SIZE]; // Hashtable is just an int[]
class MeasurementRepository {
private int size = 16384;// 16384; // Much larger than the number of cities, needs power of two
private int[] indices = new int[size]; // Hashtable is just an int[]
BitTwiddledMap() {
MeasurementRepository() {
populateEmptyIndices(indices);
}
private void populateEmptyIndices(int[] array) {
// Optimized fill with -1, fastest method:
int len = indices.length;
if (len > 0) {
indices[0] = -1;
}
int len = array.length;
array[0] = -1;
// Value of i will be [1, 2, 4, 8, 16, 32, ..., len]
for (int i = 1; i < len; i += i) {
System.arraycopy(indices, 0, indices, i, i);
System.arraycopy(array, 0, array, i, i);
}
}
private List<Entry> values = new ArrayList<>(512);
private final List<Entry> values = new ArrayList<>(512);
record Entry(int hash, byte[] key, Measurement measurement) {
record Entry(int hash, long[] cityNameAsLong, String cityName, Measurement measurement) {
@Override
public String toString() {
return new String(key) + "=" + measurement;
return cityName + "=" + measurement;
}
}
/**
* Who needs methods like add(), merge(), compute() etc, we need one, getOrCreate.
* @param key
* @return
*/
public Measurement getOrCreate(byte[] key, int length) {
int inHash;
int index = (SIZE - 1) & (inHash = hashCode(key, length));
public Measurement update(byte[] buffer, int length, int calculatedHash, long[] cityNameAsLongArray) {
final int cityNameAsLongLength = 1 + (length >>> 3); // amount of longs that captures this cityname
int hashtableIndex = (size - 1) & calculatedHash;
int valueIndex;
Entry retrievedEntry = null;
while ((valueIndex = indices[index]) != -1 && (retrievedEntry = values.get(valueIndex)).hash != inHash) {
index = (index + 1) % SIZE;
while (true) { // search for the right spot
if ((valueIndex = indices[hashtableIndex]) == -1) {
break; // Empty slot found, stop the loop
}
else {
// Non-empty slot, retrieve entry
if ((retrievedEntry = values.get(valueIndex)).hash == calculatedHash &&
arrayEquals(retrievedEntry.cityNameAsLong, cityNameAsLongArray, cityNameAsLongLength)) {
break; // Both hash and cityname match, stop the loop
}
}
// Move to the next index
hashtableIndex = (hashtableIndex + 1) % size;
}
if (valueIndex >= 0) {
return retrievedEntry.measurement;
}
// New entry, insert into table and return.
indices[index] = values.size();
// Only parse this once:
byte[] actualKey = new byte[length];
System.arraycopy(key, 0, actualKey, 0, length);
// --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!)
// Keep the already processed longs for fast equals:
long[] cityNameAsLongArrayCopy = new long[cityNameAsLongLength];
System.arraycopy(cityNameAsLongArray, 0, cityNameAsLongArrayCopy, 0, cityNameAsLongLength);
Entry toAdd = new Entry(calculatedHash, cityNameAsLongArrayCopy, new String(buffer, 0, length), new Measurement());
// Code to regrow (if we get more unique entries): (not needed/not optimized yet)
// if (values.size() > size / 2) {
// // We probably don't want this...
//
// int newSize = size << 1;
// int[] newIndices = new int[newSize];
// populateEmptyIndices(newIndices);
// for (int i = 0; i < values.size(); i++) {
// Entry e = values.get(i);
// int updatedIndex = (newSize - 1) & e.hash;
// newIndices[updatedIndex] = i;
// }
// indices = newIndices;
// size = newSize;
// }
indices[hashtableIndex] = values.size();
Entry toAdd = new Entry(inHash, actualKey, new Measurement());
values.add(toAdd);
return toAdd.measurement;
}
private static int hashCode(byte[] a, int length) {
int result = 1;
for (int i = 0; i < length; i++) {
result = 31 * result + a[i];
}
return result;
}
}
/**
* For case multiple hashes are equal (however unlikely) check the actual key (using longs)
*/
private boolean arrayEquals(final long[] a, final long[] b, final int length) {
for (int i = 0; i < length; i++) {
if (a[i] != b[i])
return false;
}
return true;
}
}