From af2b5517c894347d42e8382b4b7559bdd9a7d337 Mon Sep 17 00:00:00 2001 From: Anita SV Date: Wed, 31 Jan 2024 00:41:33 -0800 Subject: [PATCH] anitasv 3.8s vs 3m 19s : Improved using custom hashmap. (#672) * Some optimizations while staying safe * bug fix not caught on tests --- .../onebrc/CalculateAverage_anitasv.java | 167 +++++++++++++----- 1 file changed, 120 insertions(+), 47 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_anitasv.java b/src/main/java/dev/morling/onebrc/CalculateAverage_anitasv.java index c15250d..7d3d6af 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_anitasv.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_anitasv.java @@ -25,7 +25,6 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.*; -import java.util.stream.Collectors; import java.util.stream.IntStream; public class CalculateAverage_anitasv { @@ -44,14 +43,14 @@ public class CalculateAverage_anitasv { .asByteBuffer(); while (buf.hasRemaining()) { if (buf.get() == ch) { - return position + buf.position() - 1; + return position + (buf.position() - 1); } } return -1; } - byte[] getRange(long start, long end) { - return mmapMemory.asSlice(start, end - start).toArray(ValueLayout.JAVA_BYTE); + MemorySegment getRange(long start, long end) { + return mmapMemory.asSlice(start, end - start); } int parseDouble(long start, long end) { @@ -86,22 +85,122 @@ public class CalculateAverage_anitasv { return buf2.hashCode(); } - public boolean matches(byte[] existingStation, long start, long end) { - ByteBuffer buf1 = ByteBuffer.wrap(existingStation); - ByteBuffer buf2 = mmapMemory.asSlice(start, end - start).asByteBuffer(); - return buf1.equals(buf2); + public long truncate(long index) { + return Math.min(index, mmapMemory.byteSize()); + } + + public long getLong(long position) { + return mmapMemory.get(ValueLayout.JAVA_LONG_UNALIGNED, position); } } - private record ResultRow(byte[] station, IntSummaryStatistics statistics) { - - public String toString() { - return STR."\{new String(station, StandardCharsets.UTF_8)} : \{statToString(statistics)}"; - } + private record ResultRow(IntSummaryStatistics statistics, int keyLength, int next) { } - private static Map process(Shard shard) { - HashMap> result = new HashMap<>(1 << 14); + private static class FastHashMap { + private final byte[] keys; + private final ResultRow[] values; + + private final int capacityMinusOne; + + private final MemorySegment keySegment; + + private int next = -1; + + private FastHashMap(int capacity) { + this.capacityMinusOne = capacity - 1; + this.keys = new byte[capacity << 7]; + this.keySegment = MemorySegment.ofArray(keys); + this.values = new ResultRow[capacity]; + } + + IntSummaryStatistics find(int hash, Shard shard, long stationStart, long stationEnd) { + int initialIndex = hash & capacityMinusOne; + int lookupLength = (int) (stationEnd - stationStart); + int lookupAligned = ((lookupLength + 7) & (-8)); + int i = initialIndex; + + lookupAligned = (int) (shard.truncate(stationStart + lookupAligned) - stationStart) - 7; + + do { + int keyIndex = i << 7; + + if (keys[keyIndex] != 0 && keys[keyIndex + lookupLength] == 0) { + + int mismatch = -1, j; + for (j = 0; j < lookupAligned; j += 8) { + long entryLong = keySegment.get(ValueLayout.JAVA_LONG_UNALIGNED, keyIndex + j); + long lookupLong = shard.getLong(stationStart + j); + if (entryLong != lookupLong) { + int diff = Long.numberOfTrailingZeros(entryLong ^ lookupLong); + mismatch = j + (diff >> 3); + break; + } + } + if (mismatch == -1) { + for (; j < lookupLength; j++) { + byte entryByte = keys[keyIndex + j]; + byte lookupByte = shard.getByte(stationStart + j); + if (entryByte != lookupByte) { + mismatch = j; + break; + } + } + } + if (mismatch == -1 || mismatch >= lookupLength) { + return this.values[i].statistics; + } + } + if (keys[keyIndex] == 0) { + MemorySegment fullLookup = shard.getRange(stationStart, stationEnd); + + keySegment.asSlice(keyIndex, lookupLength) + .copyFrom(fullLookup); + + keys[keyIndex + lookupLength] = 0; + IntSummaryStatistics stats = new IntSummaryStatistics(); + ResultRow resultRow = new ResultRow(stats, lookupLength, this.next); + this.next = i; + this.values[i] = resultRow; + return stats; + } + + if (i == capacityMinusOne) { + i = 0; + } + else { + i++; + } + } while (i != initialIndex); + throw new IllegalStateException("Hash size too small"); + } + + Iterable> values() { + return () -> new Iterator<>() { + + int scan = FastHashMap.this.next; + + @Override + public boolean hasNext() { + return scan != -1; + } + + @Override + public Map.Entry next() { + ResultRow resultRow = values[scan]; + IntSummaryStatistics stats = resultRow.statistics; + String key = new String(keys, scan << 7, resultRow.keyLength, + StandardCharsets.UTF_8); + scan = resultRow.next; + return new AbstractMap.SimpleEntry<>(key, stats); + } + }; + } + + } + + private static Iterable> process(Shard shard) { + FastHashMap result = new FastHashMap(1 << 14); boolean skip = shard.chunkStart != 0; for (long position = shard.chunkStart; position < shard.chunkEnd; position++) { @@ -116,45 +215,19 @@ public class CalculateAverage_anitasv { long temperatureEnd = shard.indexOf(stationEnd + 1, (byte) '\n'); int temperature = shard.parseDouble(stationEnd + 1, temperatureEnd); - List collisions = result.get(hash); - if (collisions == null) { - collisions = new ArrayList<>(); - result.put(hash, collisions); - } - - boolean found = false; - for (ResultRow existing : collisions) { - byte[] existingStation = existing.station(); - if (shard.matches(existingStation, position, stationEnd)) { - existing.statistics.accept(temperature); - found = true; - break; - } - } - if (!found) { - IntSummaryStatistics stats = new IntSummaryStatistics(); - stats.accept(temperature); - ResultRow rr = new ResultRow(shard.getRange(position, stationEnd), stats); - collisions.add(rr); - } + IntSummaryStatistics stats = result.find(hash, shard, position, stationEnd); + stats.accept(temperature); position = temperatureEnd; } } - return result.values() - .stream() - .flatMap(Collection::stream) - .map(rr -> new AbstractMap.SimpleImmutableEntry<>( - new String(rr.station, StandardCharsets.UTF_8), - rr.statistics)) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + return result.values(); } - private static Map combineResults(List> list) { - + private static Map combineResults(List>> list) { Map output = HashMap.newHashMap(1024); - for (Map map : list) { - for (Map.Entry entry : map.entrySet()) { + for (Iterable> map : list) { + for (Map.Entry entry : map) { output.compute(entry.getKey(), (ignore, val) -> { if (val == null) { return entry.getValue();