diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java index 8aa1a95..4e6b255 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java @@ -24,7 +24,6 @@ import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.Future; -import java.util.stream.Collectors; public class CalculateAverage_zerninv { private static final String FILE = "./measurements.txt"; @@ -91,9 +90,9 @@ public class CalculateAverage_zerninv { } private static Map calcForChunk(FileChannel channel, long begin, long end) throws IOException { - var results = new HashMap(10_000); var mbb = channel.map(FileChannel.MapMode.READ_ONLY, begin, end - begin); - int cityOffset, hashCode, temperatureOffset, temperature; + var results = new MeasurementContainer(mbb); + int cityOffset, cityNameSize, hashCode, temperatureOffset, temperature; byte b; while (mbb.hasRemaining()) { @@ -104,7 +103,7 @@ public class CalculateAverage_zerninv { } temperatureOffset = mbb.position(); - CityWrapper city = new CityWrapper(mbb, cityOffset, temperatureOffset - cityOffset - 1, hashCode); + cityNameSize = temperatureOffset - cityOffset - 1; temperature = 0; while ((b = mbb.get()) != LINE_SEPARATOR) { @@ -115,32 +114,22 @@ public class CalculateAverage_zerninv { if (mbb.get(temperatureOffset) == MINUS) { temperature *= -1; } - - var result = results.get(city); - if (result != null) { - result.addTemperature(temperature); - } - else { - results.put(city, new MeasurementAggregation().addTemperature(temperature)); - } + results.put(cityOffset, cityNameSize, hashCode, (short) temperature); } - return results.entrySet() - .stream() - .collect(Collectors.toMap(entry -> entry.getKey().toString(), Map.Entry::getValue)); + return results.toStringMap(); } private static final class MeasurementAggregation { private long sum; private int count; - private int min = Integer.MAX_VALUE; - private int max = Integer.MIN_VALUE; + private short min; + private short max; - public MeasurementAggregation addTemperature(int temperature) { - sum += temperature; - count++; - min = Math.min(temperature, min); - max = Math.max(temperature, max); - return this; + public MeasurementAggregation(long sum, int count, short min, short max) { + this.sum = sum; + this.count = count; + this.min = min; + this.max = max; } public void merge(MeasurementAggregation o) { @@ -149,8 +138,8 @@ public class CalculateAverage_zerninv { } sum += o.sum; count += o.count; - min = Math.min(min, o.min); - max = Math.max(max, o.max); + min = min < o.min ? min : o.min; + max = max > o.max ? max : o.max; } @Override @@ -159,39 +148,77 @@ public class CalculateAverage_zerninv { } } - private record CityWrapper(MappedByteBuffer mbb, int begin, int size, int hash) { + private static final class MeasurementContainer { + private static final int SIZE = 1024 * 16; - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } + private final MappedByteBuffer mbb; + private final int[] offsets = new int[SIZE]; + private final int[] sizes = new int[SIZE]; + private final int[] hashes = new int[SIZE]; - CityWrapper that = (CityWrapper) o; - if (hash != that.hash || size != that.size) { - return false; + private final long[] sums = new long[SIZE]; + private final int[] counts = new int[SIZE]; + private final short[] mins = new short[SIZE]; + private final short[] maxs = new short[SIZE]; + + private MeasurementContainer(MappedByteBuffer mbb) { + this.mbb = mbb; + Arrays.fill(mins, Short.MAX_VALUE); + Arrays.fill(maxs, Short.MIN_VALUE); + } + + public void put(int offset, int size, int hash, short value) { + int i = findIdx(offset, size, hash); + offsets[i] = offset; + sizes[i] = size; + hashes[i] = hash; + + sums[i] += value; + counts[i]++; + + if (value < mins[i]) { + mins[i] = value; } - for (int i = 0; i < size; i++) { - if (mbb.get(begin + i) != mbb.get(that.begin + i)) { + if (value > maxs[i]) { + maxs[i] = value; + } + } + + public Map toStringMap() { + var result = new HashMap(); + for (int i = 0; i < SIZE; i++) { + if (counts[i] != 0) { + var key = createString(offsets[i], sizes[i]); + result.put(key, new MeasurementAggregation(sums[i], counts[i], mins[i], maxs[i])); + } + } + return result; + } + + private int findIdx(int offset, int size, int hash) { + int i = Math.abs(hash % SIZE); + while (counts[i] != 0) { + if (hashes[i] == hash && sizes[i] == size && isEqual(i, offset)) { + break; + } + i = (i + 1) % SIZE; + } + return i; + } + + private boolean isEqual(int index, int offset) { + for (int i = 0; i < sizes[index]; i++) { + if (mbb.get(offsets[index] + i) != mbb.get(offset + i)) { return false; } } return true; } - @Override - public int hashCode() { - return hash; - } - - @Override - public String toString() { + private String createString(int offset, int size) { byte[] arr = new byte[size]; for (int i = 0; i < size; i++) { - arr[i] = mbb.get(begin + i); + arr[i] = mbb.get(offset + i); } return new String(arr); }