anitasv 3.8s vs 3m 19s : Improved using custom hashmap. (#672)
* Some optimizations while staying safe * bug fix not caught on tests
This commit is contained in:
parent
974ddbae60
commit
af2b5517c8
@ -25,7 +25,6 @@ import java.nio.charset.StandardCharsets;
|
|||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.StandardOpenOption;
|
import java.nio.file.StandardOpenOption;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import java.util.stream.IntStream;
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
public class CalculateAverage_anitasv {
|
public class CalculateAverage_anitasv {
|
||||||
@ -44,14 +43,14 @@ public class CalculateAverage_anitasv {
|
|||||||
.asByteBuffer();
|
.asByteBuffer();
|
||||||
while (buf.hasRemaining()) {
|
while (buf.hasRemaining()) {
|
||||||
if (buf.get() == ch) {
|
if (buf.get() == ch) {
|
||||||
return position + buf.position() - 1;
|
return position + (buf.position() - 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
byte[] getRange(long start, long end) {
|
MemorySegment getRange(long start, long end) {
|
||||||
return mmapMemory.asSlice(start, end - start).toArray(ValueLayout.JAVA_BYTE);
|
return mmapMemory.asSlice(start, end - start);
|
||||||
}
|
}
|
||||||
|
|
||||||
int parseDouble(long start, long end) {
|
int parseDouble(long start, long end) {
|
||||||
@ -86,22 +85,122 @@ public class CalculateAverage_anitasv {
|
|||||||
return buf2.hashCode();
|
return buf2.hashCode();
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean matches(byte[] existingStation, long start, long end) {
|
public long truncate(long index) {
|
||||||
ByteBuffer buf1 = ByteBuffer.wrap(existingStation);
|
return Math.min(index, mmapMemory.byteSize());
|
||||||
ByteBuffer buf2 = mmapMemory.asSlice(start, end - start).asByteBuffer();
|
}
|
||||||
return buf1.equals(buf2);
|
|
||||||
|
public long getLong(long position) {
|
||||||
|
return mmapMemory.get(ValueLayout.JAVA_LONG_UNALIGNED, position);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private record ResultRow(byte[] station, IntSummaryStatistics statistics) {
|
private record ResultRow(IntSummaryStatistics statistics, int keyLength, int next) {
|
||||||
|
|
||||||
public String toString() {
|
|
||||||
return STR."\{new String(station, StandardCharsets.UTF_8)} : \{statToString(statistics)}";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Map<String, IntSummaryStatistics> process(Shard shard) {
|
private static class FastHashMap {
|
||||||
HashMap<Integer, List<ResultRow>> result = new HashMap<>(1 << 14);
|
private final byte[] keys;
|
||||||
|
private final ResultRow[] values;
|
||||||
|
|
||||||
|
private final int capacityMinusOne;
|
||||||
|
|
||||||
|
private final MemorySegment keySegment;
|
||||||
|
|
||||||
|
private int next = -1;
|
||||||
|
|
||||||
|
private FastHashMap(int capacity) {
|
||||||
|
this.capacityMinusOne = capacity - 1;
|
||||||
|
this.keys = new byte[capacity << 7];
|
||||||
|
this.keySegment = MemorySegment.ofArray(keys);
|
||||||
|
this.values = new ResultRow[capacity];
|
||||||
|
}
|
||||||
|
|
||||||
|
IntSummaryStatistics find(int hash, Shard shard, long stationStart, long stationEnd) {
|
||||||
|
int initialIndex = hash & capacityMinusOne;
|
||||||
|
int lookupLength = (int) (stationEnd - stationStart);
|
||||||
|
int lookupAligned = ((lookupLength + 7) & (-8));
|
||||||
|
int i = initialIndex;
|
||||||
|
|
||||||
|
lookupAligned = (int) (shard.truncate(stationStart + lookupAligned) - stationStart) - 7;
|
||||||
|
|
||||||
|
do {
|
||||||
|
int keyIndex = i << 7;
|
||||||
|
|
||||||
|
if (keys[keyIndex] != 0 && keys[keyIndex + lookupLength] == 0) {
|
||||||
|
|
||||||
|
int mismatch = -1, j;
|
||||||
|
for (j = 0; j < lookupAligned; j += 8) {
|
||||||
|
long entryLong = keySegment.get(ValueLayout.JAVA_LONG_UNALIGNED, keyIndex + j);
|
||||||
|
long lookupLong = shard.getLong(stationStart + j);
|
||||||
|
if (entryLong != lookupLong) {
|
||||||
|
int diff = Long.numberOfTrailingZeros(entryLong ^ lookupLong);
|
||||||
|
mismatch = j + (diff >> 3);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (mismatch == -1) {
|
||||||
|
for (; j < lookupLength; j++) {
|
||||||
|
byte entryByte = keys[keyIndex + j];
|
||||||
|
byte lookupByte = shard.getByte(stationStart + j);
|
||||||
|
if (entryByte != lookupByte) {
|
||||||
|
mismatch = j;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (mismatch == -1 || mismatch >= lookupLength) {
|
||||||
|
return this.values[i].statistics;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (keys[keyIndex] == 0) {
|
||||||
|
MemorySegment fullLookup = shard.getRange(stationStart, stationEnd);
|
||||||
|
|
||||||
|
keySegment.asSlice(keyIndex, lookupLength)
|
||||||
|
.copyFrom(fullLookup);
|
||||||
|
|
||||||
|
keys[keyIndex + lookupLength] = 0;
|
||||||
|
IntSummaryStatistics stats = new IntSummaryStatistics();
|
||||||
|
ResultRow resultRow = new ResultRow(stats, lookupLength, this.next);
|
||||||
|
this.next = i;
|
||||||
|
this.values[i] = resultRow;
|
||||||
|
return stats;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i == capacityMinusOne) {
|
||||||
|
i = 0;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
} while (i != initialIndex);
|
||||||
|
throw new IllegalStateException("Hash size too small");
|
||||||
|
}
|
||||||
|
|
||||||
|
Iterable<Map.Entry<String, IntSummaryStatistics>> values() {
|
||||||
|
return () -> new Iterator<>() {
|
||||||
|
|
||||||
|
int scan = FastHashMap.this.next;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return scan != -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map.Entry<String, IntSummaryStatistics> next() {
|
||||||
|
ResultRow resultRow = values[scan];
|
||||||
|
IntSummaryStatistics stats = resultRow.statistics;
|
||||||
|
String key = new String(keys, scan << 7, resultRow.keyLength,
|
||||||
|
StandardCharsets.UTF_8);
|
||||||
|
scan = resultRow.next;
|
||||||
|
return new AbstractMap.SimpleEntry<>(key, stats);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Iterable<Map.Entry<String, IntSummaryStatistics>> process(Shard shard) {
|
||||||
|
FastHashMap result = new FastHashMap(1 << 14);
|
||||||
|
|
||||||
boolean skip = shard.chunkStart != 0;
|
boolean skip = shard.chunkStart != 0;
|
||||||
for (long position = shard.chunkStart; position < shard.chunkEnd; position++) {
|
for (long position = shard.chunkStart; position < shard.chunkEnd; position++) {
|
||||||
@ -116,45 +215,19 @@ public class CalculateAverage_anitasv {
|
|||||||
long temperatureEnd = shard.indexOf(stationEnd + 1, (byte) '\n');
|
long temperatureEnd = shard.indexOf(stationEnd + 1, (byte) '\n');
|
||||||
int temperature = shard.parseDouble(stationEnd + 1, temperatureEnd);
|
int temperature = shard.parseDouble(stationEnd + 1, temperatureEnd);
|
||||||
|
|
||||||
List<ResultRow> collisions = result.get(hash);
|
IntSummaryStatistics stats = result.find(hash, shard, position, stationEnd);
|
||||||
if (collisions == null) {
|
stats.accept(temperature);
|
||||||
collisions = new ArrayList<>();
|
|
||||||
result.put(hash, collisions);
|
|
||||||
}
|
|
||||||
|
|
||||||
boolean found = false;
|
|
||||||
for (ResultRow existing : collisions) {
|
|
||||||
byte[] existingStation = existing.station();
|
|
||||||
if (shard.matches(existingStation, position, stationEnd)) {
|
|
||||||
existing.statistics.accept(temperature);
|
|
||||||
found = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!found) {
|
|
||||||
IntSummaryStatistics stats = new IntSummaryStatistics();
|
|
||||||
stats.accept(temperature);
|
|
||||||
ResultRow rr = new ResultRow(shard.getRange(position, stationEnd), stats);
|
|
||||||
collisions.add(rr);
|
|
||||||
}
|
|
||||||
position = temperatureEnd;
|
position = temperatureEnd;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result.values()
|
return result.values();
|
||||||
.stream()
|
|
||||||
.flatMap(Collection::stream)
|
|
||||||
.map(rr -> new AbstractMap.SimpleImmutableEntry<>(
|
|
||||||
new String(rr.station, StandardCharsets.UTF_8),
|
|
||||||
rr.statistics))
|
|
||||||
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Map<String, IntSummaryStatistics> combineResults(List<Map<String, IntSummaryStatistics>> list) {
|
private static Map<String, IntSummaryStatistics> combineResults(List<Iterable<Map.Entry<String, IntSummaryStatistics>>> list) {
|
||||||
|
|
||||||
Map<String, IntSummaryStatistics> output = HashMap.newHashMap(1024);
|
Map<String, IntSummaryStatistics> output = HashMap.newHashMap(1024);
|
||||||
for (Map<String, IntSummaryStatistics> map : list) {
|
for (Iterable<Map.Entry<String, IntSummaryStatistics>> map : list) {
|
||||||
for (Map.Entry<String, IntSummaryStatistics> entry : map.entrySet()) {
|
for (Map.Entry<String, IntSummaryStatistics> entry : map) {
|
||||||
output.compute(entry.getKey(), (ignore, val) -> {
|
output.compute(entry.getKey(), (ignore, val) -> {
|
||||||
if (val == null) {
|
if (val == null) {
|
||||||
return entry.getValue();
|
return entry.getValue();
|
||||||
|
Loading…
Reference in New Issue
Block a user