anitasv 3.8s vs 3m 19s : Improved using custom hashmap. (#672)

* Some optimizations while staying safe

* bug fix not caught on tests
This commit is contained in:
Anita SV 2024-01-31 00:41:33 -08:00 committed by GitHub
parent 974ddbae60
commit af2b5517c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -25,7 +25,6 @@ import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class CalculateAverage_anitasv {
@ -44,14 +43,14 @@ public class CalculateAverage_anitasv {
.asByteBuffer();
while (buf.hasRemaining()) {
if (buf.get() == ch) {
return position + buf.position() - 1;
return position + (buf.position() - 1);
}
}
return -1;
}
byte[] getRange(long start, long end) {
return mmapMemory.asSlice(start, end - start).toArray(ValueLayout.JAVA_BYTE);
MemorySegment getRange(long start, long end) {
return mmapMemory.asSlice(start, end - start);
}
int parseDouble(long start, long end) {
@ -86,22 +85,122 @@ public class CalculateAverage_anitasv {
return buf2.hashCode();
}
public boolean matches(byte[] existingStation, long start, long end) {
ByteBuffer buf1 = ByteBuffer.wrap(existingStation);
ByteBuffer buf2 = mmapMemory.asSlice(start, end - start).asByteBuffer();
return buf1.equals(buf2);
public long truncate(long index) {
return Math.min(index, mmapMemory.byteSize());
}
public long getLong(long position) {
return mmapMemory.get(ValueLayout.JAVA_LONG_UNALIGNED, position);
}
}
private record ResultRow(byte[] station, IntSummaryStatistics statistics) {
public String toString() {
return STR."\{new String(station, StandardCharsets.UTF_8)} : \{statToString(statistics)}";
}
private record ResultRow(IntSummaryStatistics statistics, int keyLength, int next) {
}
private static Map<String, IntSummaryStatistics> process(Shard shard) {
HashMap<Integer, List<ResultRow>> result = new HashMap<>(1 << 14);
private static class FastHashMap {
private final byte[] keys;
private final ResultRow[] values;
private final int capacityMinusOne;
private final MemorySegment keySegment;
private int next = -1;
private FastHashMap(int capacity) {
this.capacityMinusOne = capacity - 1;
this.keys = new byte[capacity << 7];
this.keySegment = MemorySegment.ofArray(keys);
this.values = new ResultRow[capacity];
}
IntSummaryStatistics find(int hash, Shard shard, long stationStart, long stationEnd) {
int initialIndex = hash & capacityMinusOne;
int lookupLength = (int) (stationEnd - stationStart);
int lookupAligned = ((lookupLength + 7) & (-8));
int i = initialIndex;
lookupAligned = (int) (shard.truncate(stationStart + lookupAligned) - stationStart) - 7;
do {
int keyIndex = i << 7;
if (keys[keyIndex] != 0 && keys[keyIndex + lookupLength] == 0) {
int mismatch = -1, j;
for (j = 0; j < lookupAligned; j += 8) {
long entryLong = keySegment.get(ValueLayout.JAVA_LONG_UNALIGNED, keyIndex + j);
long lookupLong = shard.getLong(stationStart + j);
if (entryLong != lookupLong) {
int diff = Long.numberOfTrailingZeros(entryLong ^ lookupLong);
mismatch = j + (diff >> 3);
break;
}
}
if (mismatch == -1) {
for (; j < lookupLength; j++) {
byte entryByte = keys[keyIndex + j];
byte lookupByte = shard.getByte(stationStart + j);
if (entryByte != lookupByte) {
mismatch = j;
break;
}
}
}
if (mismatch == -1 || mismatch >= lookupLength) {
return this.values[i].statistics;
}
}
if (keys[keyIndex] == 0) {
MemorySegment fullLookup = shard.getRange(stationStart, stationEnd);
keySegment.asSlice(keyIndex, lookupLength)
.copyFrom(fullLookup);
keys[keyIndex + lookupLength] = 0;
IntSummaryStatistics stats = new IntSummaryStatistics();
ResultRow resultRow = new ResultRow(stats, lookupLength, this.next);
this.next = i;
this.values[i] = resultRow;
return stats;
}
if (i == capacityMinusOne) {
i = 0;
}
else {
i++;
}
} while (i != initialIndex);
throw new IllegalStateException("Hash size too small");
}
Iterable<Map.Entry<String, IntSummaryStatistics>> values() {
return () -> new Iterator<>() {
int scan = FastHashMap.this.next;
@Override
public boolean hasNext() {
return scan != -1;
}
@Override
public Map.Entry<String, IntSummaryStatistics> next() {
ResultRow resultRow = values[scan];
IntSummaryStatistics stats = resultRow.statistics;
String key = new String(keys, scan << 7, resultRow.keyLength,
StandardCharsets.UTF_8);
scan = resultRow.next;
return new AbstractMap.SimpleEntry<>(key, stats);
}
};
}
}
private static Iterable<Map.Entry<String, IntSummaryStatistics>> process(Shard shard) {
FastHashMap result = new FastHashMap(1 << 14);
boolean skip = shard.chunkStart != 0;
for (long position = shard.chunkStart; position < shard.chunkEnd; position++) {
@ -116,45 +215,19 @@ public class CalculateAverage_anitasv {
long temperatureEnd = shard.indexOf(stationEnd + 1, (byte) '\n');
int temperature = shard.parseDouble(stationEnd + 1, temperatureEnd);
List<ResultRow> collisions = result.get(hash);
if (collisions == null) {
collisions = new ArrayList<>();
result.put(hash, collisions);
}
boolean found = false;
for (ResultRow existing : collisions) {
byte[] existingStation = existing.station();
if (shard.matches(existingStation, position, stationEnd)) {
existing.statistics.accept(temperature);
found = true;
break;
}
}
if (!found) {
IntSummaryStatistics stats = new IntSummaryStatistics();
stats.accept(temperature);
ResultRow rr = new ResultRow(shard.getRange(position, stationEnd), stats);
collisions.add(rr);
}
IntSummaryStatistics stats = result.find(hash, shard, position, stationEnd);
stats.accept(temperature);
position = temperatureEnd;
}
}
return result.values()
.stream()
.flatMap(Collection::stream)
.map(rr -> new AbstractMap.SimpleImmutableEntry<>(
new String(rr.station, StandardCharsets.UTF_8),
rr.statistics))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
return result.values();
}
private static Map<String, IntSummaryStatistics> combineResults(List<Map<String, IntSummaryStatistics>> list) {
private static Map<String, IntSummaryStatistics> combineResults(List<Iterable<Map.Entry<String, IntSummaryStatistics>>> list) {
Map<String, IntSummaryStatistics> output = HashMap.newHashMap(1024);
for (Map<String, IntSummaryStatistics> map : list) {
for (Map.Entry<String, IntSummaryStatistics> entry : map.entrySet()) {
for (Iterable<Map.Entry<String, IntSummaryStatistics>> map : list) {
for (Map.Entry<String, IntSummaryStatistics> entry : map) {
output.compute(entry.getKey(), (ignore, val) -> {
if (val == null) {
return entry.getValue();