anestoruk second attempt (#625)

* initial implementation

* few improvements and a cleanup (down to ~12s)

* use array instead of hashmap for collecting partial results
This commit is contained in:
Andrzej Nestoruk 2024-01-28 22:59:04 +01:00 committed by GitHub
parent 9282fb7b0a
commit ff35a4628b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,9 +22,7 @@ import java.nio.channels.FileChannel;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -69,14 +67,14 @@ public class CalculateAverage_anestoruk {
TreeMap<String, Record> result = new TreeMap<>(); TreeMap<String, Record> result = new TreeMap<>();
try (ExecutorService executor = Executors.newFixedThreadPool(cpus)) { try (ExecutorService executor = Executors.newFixedThreadPool(cpus)) {
List<CompletableFuture<Map<ByteWrapper, Record>>> futures = new ArrayList<>(); List<CompletableFuture<Record[]>> futures = new ArrayList<>();
for (SegmentRange range : rangeList) { for (SegmentRange range : rangeList) {
futures.add(supplyAsync(() -> process(range, segment), executor)); futures.add(supplyAsync(() -> process(range, segment), executor));
} }
for (CompletableFuture<Map<ByteWrapper, Record>> future : futures) { for (CompletableFuture<Record[]> future : futures) {
try { try {
Map<ByteWrapper, Record> partialResult = future.get(); Record[] partialResult = future.get();
combine(result, partialResult); mergeResult(result, partialResult);
} }
catch (InterruptedException | ExecutionException ex) { catch (InterruptedException | ExecutionException ex) {
throw new RuntimeException(ex); throw new RuntimeException(ex);
@ -87,20 +85,19 @@ public class CalculateAverage_anestoruk {
System.out.println(result); System.out.println(result);
} }
private static Map<ByteWrapper, Record> process(SegmentRange range, MemorySegment segment) { private static Record[] process(SegmentRange range, MemorySegment segment) {
Map<ByteWrapper, Record> partialResult = new HashMap<>(1_000); Record[] records = new Record[1024 * 100];
byte[] buffer = new byte[100]; byte[] cityBuffer = new byte[100];
long offset = range.startOffset; long offset = range.startOffset;
byte b; byte b;
while (offset < range.endOffset) { while (offset < range.endOffset) {
int cityIdx = 0; int cityLength = 0;
int hash = 0;
while ((b = segment.get(JAVA_BYTE, offset++)) != ';') { while ((b = segment.get(JAVA_BYTE, offset++)) != ';') {
buffer[cityIdx++] = b; cityBuffer[cityLength++] = b;
hash = hash * 31 + b;
} }
byte[] city = new byte[cityIdx]; hash = Math.abs(hash);
System.arraycopy(buffer, 0, city, 0, cityIdx);
ByteWrapper cityWrapper = new ByteWrapper(city);
int value = 0; int value = 0;
boolean negative; boolean negative;
if ((b = segment.get(JAVA_BYTE, offset++)) == '-') { if ((b = segment.get(JAVA_BYTE, offset++)) == '-') {
@ -116,45 +113,77 @@ public class CalculateAverage_anestoruk {
} }
} }
int temperature = negative ? -value : value; int temperature = negative ? -value : value;
byte[] city = new byte[cityLength];
partialResult.compute(cityWrapper, (_, record) -> update(record, temperature)); System.arraycopy(cityBuffer, 0, city, 0, cityLength);
addResult(records, hash, city, temperature);
}
return records;
}
private static void addResult(Record[] records, int hash, byte[] city, int temperature) {
int idx = hash % records.length;
Record record;
while ((record = records[idx]) != null) {
if (record.hash == hash && Arrays.equals(record.city, city)) {
record.add(temperature);
return;
}
idx = (idx + 1) % records.length;
}
records[idx] = new Record(hash, city, temperature);
}
private static void mergeResult(TreeMap<String, Record> result, Record[] partialResult) {
for (Record partialRecord : partialResult) {
if (partialRecord == null) {
continue;
}
String cityName = new String(partialRecord.city, UTF_8);
result.compute(cityName, (_, record) -> {
if (record == null) {
return partialRecord;
}
record.merge(partialRecord);
return record;
});
} }
return partialResult;
} }
private record SegmentRange(long startOffset, long endOffset) { private record SegmentRange(long startOffset, long endOffset) {
} }
private record ByteWrapper(byte[] bytes) {
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ByteWrapper that = (ByteWrapper) o;
return Arrays.equals(bytes, that.bytes);
}
@Override
public int hashCode() {
return Arrays.hashCode(bytes);
}
}
private static class Record { private static class Record {
private final int hash;
private final byte[] city;
private int min; private int min;
private int max; private int max;
private long sum; private long sum;
private int count; private int count;
public Record(int temperature) { public Record(int hash, byte[] city, int temperature) {
this.hash = hash;
this.city = city;
this.min = temperature; this.min = temperature;
this.max = temperature; this.max = temperature;
this.sum = temperature; this.sum = temperature;
this.count = 1; this.count = 1;
} }
public void add(int temperature) {
min = min(min, temperature);
max = max(max, temperature);
sum += temperature;
count++;
}
public void merge(Record other) {
min = min(min, other.min);
max = max(max, other.max);
sum += other.sum;
count += other.count;
}
@Override @Override
public String toString() { public String toString() {
return "%.1f/%.1f/%.1f".formatted( return "%.1f/%.1f/%.1f".formatted(
@ -163,31 +192,4 @@ public class CalculateAverage_anestoruk {
(max / 10.0)); (max / 10.0));
} }
} }
private static Record update(Record record, int temperature) {
if (record == null) {
return new Record(temperature);
}
record.min = min(record.min, temperature);
record.max = max(record.max, temperature);
record.sum += temperature;
record.count++;
return record;
}
private static void combine(TreeMap<String, Record> result, Map<ByteWrapper, Record> partialResult) {
partialResult.forEach((wrapper, partialRecord) -> {
String city = new String(wrapper.bytes, UTF_8);
result.compute(city, (_, record) -> {
if (record == null) {
return partialRecord;
}
record.min = min(record.min, partialRecord.min);
record.max = max(record.max, partialRecord.max);
record.sum += partialRecord.sum;
record.count += partialRecord.count;
return record;
});
});
}
} }