diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java index 42cf6b8..b28750f 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java @@ -25,9 +25,7 @@ import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.ArrayList; -import java.util.List; -import java.util.TreeMap; +import java.util.*; public class CalculateAverage_zerninv { private static final String FILE = "./measurements.txt"; @@ -55,10 +53,11 @@ public class CalculateAverage_zerninv { var tasks = new TaskThread[CORES]; for (int i = 0; i < tasks.length; i++) { - tasks[i] = new TaskThread(new MeasurementContainer(), (int) (fileSize / minChunkSize / CORES + 1)); + tasks[i] = new TaskThread((int) (fileSize / minChunkSize / CORES + 1)); } - var chunks = splitByChunks(segment.address(), segment.address() + fileSize, minChunkSize); + var results = new HashMap(); + var chunks = splitByChunks(segment.address(), segment.address() + fileSize, minChunkSize, results); for (int i = 0; i < chunks.size() - 1; i++) { var task = tasks[i % tasks.length]; task.addChunk(chunks.get(i), chunks.get(i + 1)); @@ -68,19 +67,9 @@ public class CalculateAverage_zerninv { task.start(); } - var results = new TreeMap(); for (var task : tasks) { task.join(); - task.measurements() - .forEach(measurement -> { - var aggr = results.get(measurement.station()); - if (aggr == null) { - results.put(measurement.station(), measurement.aggregation()); - } - else { - aggr.merge(measurement.aggregation()); - } - }); + task.collectTo(results); } var bos = new BufferedOutputStream(System.out); @@ -90,7 +79,31 @@ public class CalculateAverage_zerninv { } } - private static List splitByChunks(long address, long end, long minChunkSize) { + private static List splitByChunks(long address, long end, long minChunkSize, Map results) { + // handle last line + long offset = end - 1; + int temperature = 0; + byte b; + int multiplier = 1; + while ((b = UNSAFE.getByte(offset--)) != ';') { + if (b >= '0' && b <= '9') { + temperature += (b - '0') * multiplier; + multiplier *= 10; + } + else if (b == '-') { + temperature = -temperature; + } + } + long cityNameEnd = offset; + while (UNSAFE.getByte(offset - 1) != '\n' && offset > address) { + offset--; + } + var cityName = new byte[(int) (cityNameEnd - offset + 1)]; + UNSAFE.copyMemory(null, offset, cityName, Unsafe.ARRAY_BYTE_BASE_OFFSET, cityName.length); + results.put(new String(cityName, StandardCharsets.UTF_8), new TemperatureAggregation(temperature, 1, (short) temperature, (short) temperature)); + + // split by chunks + end = offset; List result = new ArrayList<>((int) ((end - address) / minChunkSize + 1)); result.add(address); while (address < end) { @@ -115,14 +128,11 @@ public class CalculateAverage_zerninv { this.max = max; } - public void merge(TemperatureAggregation o) { - if (o == null) { - return; - } - sum += o.sum; - count += o.count; - min = min < o.min ? min : o.min; - max = max > o.max ? max : o.max; + public void merge(long sum, int count, short min, short max) { + this.sum += sum; + this.count += count; + this.min = this.min < min ? this.min : min; + this.max = this.max > max ? this.max : max; } @Override @@ -131,9 +141,6 @@ public class CalculateAverage_zerninv { } } - private record Measurement(String station, TemperatureAggregation aggregation) { - } - private static final class MeasurementContainer { private static final int SIZE = 1 << 17; @@ -190,23 +197,26 @@ public class CalculateAverage_zerninv { UNSAFE.putShort(ptr + MAX_OFFSET, value); } - public List measurements() { - var result = new ArrayList(1000); + public void collectTo(Map results) { int count; for (int i = 0; i < SIZE; i++) { long ptr = this.address + i * ENTRY_SIZE; count = UNSAFE.getInt(ptr + COUNT_OFFSET); if (count != 0) { var station = createString(UNSAFE.getLong(ptr + ADDRESS_OFFSET), UNSAFE.getByte(ptr + SIZE_OFFSET)); - var measurements = new TemperatureAggregation( - UNSAFE.getLong(ptr + SUM_OFFSET), - count, - UNSAFE.getShort(ptr + MIN_OFFSET), - UNSAFE.getShort(ptr + MAX_OFFSET)); - result.add(new Measurement(station, measurements)); + var result = results.get(station); + if (result == null) { + results.put(station, new TemperatureAggregation( + UNSAFE.getLong(ptr + SUM_OFFSET), + count, + UNSAFE.getShort(ptr + MIN_OFFSET), + UNSAFE.getShort(ptr + MAX_OFFSET))); + } + else { + result.merge(UNSAFE.getLong(ptr + SUM_OFFSET), count, UNSAFE.getShort(ptr + MIN_OFFSET), UNSAFE.getShort(ptr + MAX_OFFSET)); + } } } - return result; } private boolean isEqual(long address, long address2, int size) { @@ -237,14 +247,25 @@ public class CalculateAverage_zerninv { private static final int BYTE_MASK = 0xff; private static final int ZERO = '0'; - private static final byte DELIMITER = ';'; + private static final long DELIMITER_MASK = 0x3b3b3b3b3b3b3b3bL; + private static final long[] SIGNIFICANT_BYTES_MASK = { + 0, + 0xff, + 0xffff, + 0xffffff, + 0xffffffffL, + 0xffffffffffL, + 0xffffffffffffL, + 0xffffffffffffffL, + 0xffffffffffffffffL + }; private final MeasurementContainer container; private final List begins; private final List ends; - private TaskThread(MeasurementContainer container, int chunks) { - this.container = container; + private TaskThread(int chunks) { + this.container = new MeasurementContainer(); this.begins = new ArrayList<>(chunks); this.ends = new ArrayList<>(chunks); } @@ -261,26 +282,33 @@ public class CalculateAverage_zerninv { } } - public List measurements() { - return container.measurements(); - } - private void calcForChunk(long offset, long end) { - long cityOffset, lastBytes; - int hashCode, temperature, word; - byte cityNameSize, b; + long cityOffset, lastBytes, city, masked, hashCode; + int temperature, word, delimiterIdx; + byte cityNameSize; while (offset < end) { cityOffset = offset; lastBytes = 0; hashCode = 0; - while ((b = UNSAFE.getByte(offset++)) != DELIMITER) { - hashCode += hashCode * 31 + b; - lastBytes = (lastBytes << 8) | b; - } - cityNameSize = (byte) (offset - cityOffset - 1); + delimiterIdx = 8; - word = UNSAFE.getInt(offset); + while (delimiterIdx == 8) { + city = UNSAFE.getLong(offset); + masked = city ^ DELIMITER_MASK; + masked = (masked - 0x0101010101010101L) & ~masked & 0x8080808080808080L; + delimiterIdx = Long.numberOfTrailingZeros(masked) >>> 3; + if (delimiterIdx == 0) { + break; + } + offset += delimiterIdx; + lastBytes = city & SIGNIFICANT_BYTES_MASK[delimiterIdx]; + hashCode = ((hashCode >>> 5) ^ lastBytes) * 0x517cc1b727220a95L; + } + + cityNameSize = (byte) (offset - cityOffset); + + word = UNSAFE.getInt(++offset); offset += 4; if ((word & TWO_NEGATIVE_DIGITS_MASK) == TWO_NEGATIVE_DIGITS_MASK) { @@ -300,8 +328,12 @@ public class CalculateAverage_zerninv { temperature = ZERO * 111 - ((word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK)); } offset++; - container.put(cityOffset, cityNameSize, hashCode, lastBytes, (short) temperature); + container.put(cityOffset, cityNameSize, Long.hashCode(hashCode), lastBytes, (short) temperature); } } + + public void collectTo(Map results) { + container.collectTo(results); + } } }