anitasv 3.8s vs 3m 19s : Improved using custom hashmap. (#672)
* Some optimizations while staying safe * Bug fix not caught by tests
This commit is contained in:
parent
974ddbae60
commit
af2b5517c8
@ -25,7 +25,6 @@ import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
public class CalculateAverage_anitasv {
|
||||
@ -44,14 +43,14 @@ public class CalculateAverage_anitasv {
|
||||
.asByteBuffer();
|
||||
while (buf.hasRemaining()) {
|
||||
if (buf.get() == ch) {
|
||||
return position + buf.position() - 1;
|
||||
return position + (buf.position() - 1);
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
byte[] getRange(long start, long end) {
|
||||
return mmapMemory.asSlice(start, end - start).toArray(ValueLayout.JAVA_BYTE);
|
||||
MemorySegment getRange(long start, long end) {
|
||||
return mmapMemory.asSlice(start, end - start);
|
||||
}
|
||||
|
||||
int parseDouble(long start, long end) {
|
||||
@ -86,22 +85,122 @@ public class CalculateAverage_anitasv {
|
||||
return buf2.hashCode();
|
||||
}
|
||||
|
||||
public boolean matches(byte[] existingStation, long start, long end) {
|
||||
ByteBuffer buf1 = ByteBuffer.wrap(existingStation);
|
||||
ByteBuffer buf2 = mmapMemory.asSlice(start, end - start).asByteBuffer();
|
||||
return buf1.equals(buf2);
|
||||
public long truncate(long index) {
|
||||
return Math.min(index, mmapMemory.byteSize());
|
||||
}
|
||||
|
||||
public long getLong(long position) {
|
||||
return mmapMemory.get(ValueLayout.JAVA_LONG_UNALIGNED, position);
|
||||
}
|
||||
}
|
||||
|
||||
private record ResultRow(byte[] station, IntSummaryStatistics statistics) {
|
||||
|
||||
public String toString() {
|
||||
return STR."\{new String(station, StandardCharsets.UTF_8)} : \{statToString(statistics)}";
|
||||
}
|
||||
/**
 * Value stored in one {@code FastHashMap} slot.
 *
 * @param statistics accumulator returned to callers of {@code FastHashMap.find}
 * @param keyLength  number of key bytes stored for this slot; used to decode the key into a String
 * @param next       slot index of the entry inserted before this one, or -1 for the first entry
 */
private record ResultRow(IntSummaryStatistics statistics, int keyLength, int next) {
|
||||
}
|
||||
|
||||
private static Map<String, IntSummaryStatistics> process(Shard shard) {
|
||||
HashMap<Integer, List<ResultRow>> result = new HashMap<>(1 << 14);
|
||||
private static class FastHashMap {
|
||||
private final byte[] keys;
|
||||
private final ResultRow[] values;
|
||||
|
||||
private final int capacityMinusOne;
|
||||
|
||||
private final MemorySegment keySegment;
|
||||
|
||||
private int next = -1;
|
||||
|
||||
private FastHashMap(int capacity) {
|
||||
this.capacityMinusOne = capacity - 1;
|
||||
this.keys = new byte[capacity << 7];
|
||||
this.keySegment = MemorySegment.ofArray(keys);
|
||||
this.values = new ResultRow[capacity];
|
||||
}
|
||||
|
||||
IntSummaryStatistics find(int hash, Shard shard, long stationStart, long stationEnd) {
|
||||
int initialIndex = hash & capacityMinusOne;
|
||||
int lookupLength = (int) (stationEnd - stationStart);
|
||||
int lookupAligned = ((lookupLength + 7) & (-8));
|
||||
int i = initialIndex;
|
||||
|
||||
lookupAligned = (int) (shard.truncate(stationStart + lookupAligned) - stationStart) - 7;
|
||||
|
||||
do {
|
||||
int keyIndex = i << 7;
|
||||
|
||||
if (keys[keyIndex] != 0 && keys[keyIndex + lookupLength] == 0) {
|
||||
|
||||
int mismatch = -1, j;
|
||||
for (j = 0; j < lookupAligned; j += 8) {
|
||||
long entryLong = keySegment.get(ValueLayout.JAVA_LONG_UNALIGNED, keyIndex + j);
|
||||
long lookupLong = shard.getLong(stationStart + j);
|
||||
if (entryLong != lookupLong) {
|
||||
int diff = Long.numberOfTrailingZeros(entryLong ^ lookupLong);
|
||||
mismatch = j + (diff >> 3);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (mismatch == -1) {
|
||||
for (; j < lookupLength; j++) {
|
||||
byte entryByte = keys[keyIndex + j];
|
||||
byte lookupByte = shard.getByte(stationStart + j);
|
||||
if (entryByte != lookupByte) {
|
||||
mismatch = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (mismatch == -1 || mismatch >= lookupLength) {
|
||||
return this.values[i].statistics;
|
||||
}
|
||||
}
|
||||
if (keys[keyIndex] == 0) {
|
||||
MemorySegment fullLookup = shard.getRange(stationStart, stationEnd);
|
||||
|
||||
keySegment.asSlice(keyIndex, lookupLength)
|
||||
.copyFrom(fullLookup);
|
||||
|
||||
keys[keyIndex + lookupLength] = 0;
|
||||
IntSummaryStatistics stats = new IntSummaryStatistics();
|
||||
ResultRow resultRow = new ResultRow(stats, lookupLength, this.next);
|
||||
this.next = i;
|
||||
this.values[i] = resultRow;
|
||||
return stats;
|
||||
}
|
||||
|
||||
if (i == capacityMinusOne) {
|
||||
i = 0;
|
||||
}
|
||||
else {
|
||||
i++;
|
||||
}
|
||||
} while (i != initialIndex);
|
||||
throw new IllegalStateException("Hash size too small");
|
||||
}
|
||||
|
||||
Iterable<Map.Entry<String, IntSummaryStatistics>> values() {
|
||||
return () -> new Iterator<>() {
|
||||
|
||||
int scan = FastHashMap.this.next;
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return scan != -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map.Entry<String, IntSummaryStatistics> next() {
|
||||
ResultRow resultRow = values[scan];
|
||||
IntSummaryStatistics stats = resultRow.statistics;
|
||||
String key = new String(keys, scan << 7, resultRow.keyLength,
|
||||
StandardCharsets.UTF_8);
|
||||
scan = resultRow.next;
|
||||
return new AbstractMap.SimpleEntry<>(key, stats);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static Iterable<Map.Entry<String, IntSummaryStatistics>> process(Shard shard) {
|
||||
FastHashMap result = new FastHashMap(1 << 14);
|
||||
|
||||
boolean skip = shard.chunkStart != 0;
|
||||
for (long position = shard.chunkStart; position < shard.chunkEnd; position++) {
|
||||
@ -116,45 +215,19 @@ public class CalculateAverage_anitasv {
|
||||
long temperatureEnd = shard.indexOf(stationEnd + 1, (byte) '\n');
|
||||
int temperature = shard.parseDouble(stationEnd + 1, temperatureEnd);
|
||||
|
||||
List<ResultRow> collisions = result.get(hash);
|
||||
if (collisions == null) {
|
||||
collisions = new ArrayList<>();
|
||||
result.put(hash, collisions);
|
||||
}
|
||||
|
||||
boolean found = false;
|
||||
for (ResultRow existing : collisions) {
|
||||
byte[] existingStation = existing.station();
|
||||
if (shard.matches(existingStation, position, stationEnd)) {
|
||||
existing.statistics.accept(temperature);
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
IntSummaryStatistics stats = new IntSummaryStatistics();
|
||||
stats.accept(temperature);
|
||||
ResultRow rr = new ResultRow(shard.getRange(position, stationEnd), stats);
|
||||
collisions.add(rr);
|
||||
}
|
||||
IntSummaryStatistics stats = result.find(hash, shard, position, stationEnd);
|
||||
stats.accept(temperature);
|
||||
position = temperatureEnd;
|
||||
}
|
||||
}
|
||||
|
||||
return result.values()
|
||||
.stream()
|
||||
.flatMap(Collection::stream)
|
||||
.map(rr -> new AbstractMap.SimpleImmutableEntry<>(
|
||||
new String(rr.station, StandardCharsets.UTF_8),
|
||||
rr.statistics))
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
return result.values();
|
||||
}
|
||||
|
||||
private static Map<String, IntSummaryStatistics> combineResults(List<Map<String, IntSummaryStatistics>> list) {
|
||||
|
||||
private static Map<String, IntSummaryStatistics> combineResults(List<Iterable<Map.Entry<String, IntSummaryStatistics>>> list) {
|
||||
Map<String, IntSummaryStatistics> output = HashMap.newHashMap(1024);
|
||||
for (Map<String, IntSummaryStatistics> map : list) {
|
||||
for (Map.Entry<String, IntSummaryStatistics> entry : map.entrySet()) {
|
||||
for (Iterable<Map.Entry<String, IntSummaryStatistics>> map : list) {
|
||||
for (Map.Entry<String, IntSummaryStatistics> entry : map) {
|
||||
output.compute(entry.getKey(), (ignore, val) -> {
|
||||
if (val == null) {
|
||||
return entry.getValue();
|
||||
|
Loading…
Reference in New Issue
Block a user