diff --git a/calculate_average_3j5a.sh b/calculate_average_3j5a.sh
new file mode 100755
index 0000000..b4a4277
--- /dev/null
+++ b/calculate_average_3j5a.sh
@@ -0,0 +1,19 @@
+#!/bin/sh
+#
+# Copyright 2023 The original authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+JAVA_OPTS="--add-opens=java.base/jdk.internal.util=ALL-UNNAMED"
+java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_3j5a
diff --git a/prepare_3j5a.sh b/prepare_3j5a.sh
new file mode 100755
index 0000000..06b81c4
--- /dev/null
+++ b/prepare_3j5a.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+#
+# Copyright 2023 The original authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Select the JDK through SDKMAN
+source "$HOME/.sdkman/bin/sdkman-init.sh"
+sdk use java 21.0.1-graal 1>&2
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_3j5a.java b/src/main/java/dev/morling/onebrc/CalculateAverage_3j5a.java
new file mode 100644
index 0000000..178cfac
--- /dev/null
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_3j5a.java
@@ -0,0 +1,366 @@
+/*
+ * Copyright 2023 The original authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package dev.morling.onebrc;
+
+import java.io.IOException;
+import java.io.RandomAccessFile;
+import java.lang.invoke.MethodHandle;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import static java.lang.Class.forName;
+import static java.lang.System.out;
+import static java.lang.invoke.MethodHandles.lookup;
+import static java.util.Comparator.comparing;
+
+/**
+ * Computes min/mean/max temperature per station from {@code ./measurements.txt}.
+ * <p>
+ * The file is cut into newline-aligned slices; each slice is memory-mapped and
+ * aggregated in a parallel stream, then the per-slice maps are merged and printed
+ * sorted by station name. Temperatures are held as integers in tenths of a degree.
+ */
+public class CalculateAverage_3j5a {
+
+    private static final String FILE = "./measurements.txt";
+
+    /**
+     * Aggregates {@link #FILE} and prints the result as {@code {name=min/mean/max, ...}}.
+     *
+     * @param args ignored
+     * @throws IOException if the measurements file cannot be opened or sliced
+     */
+    public static void main(String[] args) throws IOException {
+        try (RandomAccessFile measurementsFile = new RandomAccessFile(FILE, "r")) {
+            var slices = slice(measurementsFile);
+            var measurementsChannel = measurementsFile.getChannel();
+            slices.stream().parallel().map(slice -> {
+                MappedByteBuffer measurementsSlice = map(slice, measurementsChannel);
+                // Scratch buffer reused for every line of this slice.
+                var measurementBuffer = new byte[rules.maxMeasurementLength];
+                var measurements = HashMap.<Station, StationMeasurementStatistics> newHashMap(rules.uniqueStationsCount);
+                while (measurementsSlice.hasRemaining()) {
+                    var a = nextStationMeasurement(measurementBuffer, measurementsSlice);
+                    var stats = measurements.get(a.station);
+                    if (stats == null) {
+                        // The key outlives this line, so it needs its own copy of the name bytes.
+                        a.station.detachFromMeasurementBuffer();
+                        stats = new StationMeasurementStatistics(a);
+                        measurements.put(a.station, stats);
+                    }
+                    else {
+                        stats.add(a);
+                    }
+                }
+                return measurements;
+            }).reduce((aslice, bslice) -> {
+                // Fold the first map into the second and keep the second.
+                aslice.forEach((astation, astats) -> {
+                    var bstats = bslice.putIfAbsent(astation, astats);
+                    if (bstats != null) {
+                        bstats.merge(astats);
+                    }
+                });
+                return bslice;
+            }).ifPresent(measurements -> {
+                var results = new StringBuilder(measurements.size() * (rules.maxStationNameLength + rules.maxStationStatisticsOutputLength));
+                measurements.values().stream()
+                        .sorted(comparing(StationMeasurementStatistics::getName))
+                        .forEach(stationStats -> {
+                            // Separator goes in front of every entry but the first, so an
+                            // empty result prints "{}" instead of failing in substring().
+                            if (!results.isEmpty()) {
+                                results.append(", ");
+                            }
+                            results.append(stationStats);
+                        });
+                out.println("{" + results + "}");
+            });
+        }
+    }
+
+    /**
+     * Size limits the buffers and the parser are dimensioned for. Measurement lengths
+     * are counted in bytes without the trailing newline.
+     */
+    record Rules(int minMeasurementLength, int maxStationNameLength,
+                 int maxMeasurementLength, int maxStationStatisticsOutputLength,
+                 int uniqueStationsCount) {
+        Rules() {
+            this(5, 100, 106, 18, 10_000);
+        }
+    }
+
+    private static final Rules rules = new Rules();
+
+    /**
+     * Byte range of the measurements file handed to one worker.
+     */
+    record MeasurementsSlice(long start, long length) {
+    }
+
+    /**
+     * Station name as the byte range {@code name[0, length)}, used as hash map key.
+     * Until {@link #detachFromMeasurementBuffer()} is called, {@code name} may be a
+     * shared parse buffer holding more bytes than the name itself.
+     */
+    static class Station {
+
+        private byte[] name;
+        final int length;
+        private int hash;
+
+        private static final MethodHandle vectorizedHashCode;
+        // NOTE(review): presumably the basic-type code for byte expected by vectorizedHashCode - confirm.
+        private static final int T_BYTE = 8;
+
+        static {
+            try {
+                // JDK-internal class; the launch script passes --add-opens for its package.
+                var arraysSupport = forName("jdk.internal.util.ArraysSupport");
+                Class<?>[] vectorizedHashCodeSignature = { Object.class, int.class, int.class, int.class, int.class };
+                var vectorizedHashCodeMethod = arraysSupport.getDeclaredMethod("vectorizedHashCode", vectorizedHashCodeSignature);
+                vectorizedHashCode = lookup().unreflect(vectorizedHashCodeMethod);
+            }
+            catch (NoSuchMethodException | IllegalAccessException | ClassNotFoundException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        Station(byte[] name, int length) {
+            this.name = name;
+            this.length = length;
+        }
+
+        /**
+         * Replaces {@code name} with a private copy of exactly {@code length} bytes.
+         */
+        public void detachFromMeasurementBuffer() {
+            this.name = Arrays.copyOf(name, length);
+        }
+
+        /**
+         * Two stations are equal when their name ranges have the same length and bytes.
+         * Both lengths take part in the comparison: a shorter name must not match the
+         * prefix of a longer one, and the range must stay inside a detached array.
+         */
+        @Override
+        public boolean equals(Object that) {
+            if (this == that) {
+                return true;
+            }
+            if (!(that instanceof Station other)) {
+                return false;
+            }
+            return Arrays.equals(this.name, 0, length, other.name, 0, other.length);
+        }
+
+        /**
+         * Hash of the name range, computed on first use and cached in {@code hash}.
+         */
+        @Override
+        public int hashCode() {
+            if (hash == 0) {
+                try {
+                    hash = (int) vectorizedHashCode.invokeExact((Object) name, 0, length, 1, T_BYTE);
+                }
+                catch (Throwable e) {
+                    throw new RuntimeException(e);
+                }
+            }
+            return hash;
+        }
+
+    }
+
+    /**
+     * One parsed line: the station and its temperature in tenths of a degree.
+     */
+    record StationMeasurement(Station station, int temperature) {
+    }
+
+    /**
+     * Running min, max, sum and count of one station's temperatures (in tenths).
+     */
+    private static class StationMeasurementStatistics {
+
+        private final byte[] bname;
+        private String name;
+        private int min;
+        private int max;
+        private long sum;
+        private int count = 1;
+
+        StationMeasurementStatistics(StationMeasurement stationMeasurement) {
+            this.bname = stationMeasurement.station.name;
+            this.min = stationMeasurement.temperature;
+            this.max = stationMeasurement.temperature;
+            this.sum = stationMeasurement.temperature;
+        }
+
+        /**
+         * Returns the station name, decoded from UTF-8 on first call.
+         */
+        public String getName() {
+            if (name == null) {
+                name = new String(bname, StandardCharsets.UTF_8);
+            }
+            return name;
+        }
+
+        void add(StationMeasurement measurement) {
+            var temperature = measurement.temperature;
+            update(1, temperature, temperature, temperature);
+        }
+
+        void merge(StationMeasurementStatistics other) {
+            update(other.count, other.min, other.max, other.sum);
+        }
+
+        private void update(int count, int min, int max, long sum) {
+            this.count += count;
+            if (this.min > min) {
+                this.min = min;
+            }
+            if (this.max < max) {
+                this.max = max;
+            }
+            this.sum += sum;
+        }
+
+        /**
+         * Formats the statistics as {@code name=min/mean/max} with one decimal place.
+         */
+        @Override
+        public String toString() {
+            var name = getName();
+            // double, not float: a float cannot hold large sums exactly, which can tip the rounded mean.
+            var min = this.min / 10.0;
+            var mean = Math.round(this.sum / (double) this.count) / 10.0;
+            var max = this.max / 10.0;
+            return new StringBuilder(name.length() + rules.maxStationStatisticsOutputLength)
+                    .append(name).append("=").append(min).append("/").append(mean).append("/").append(max)
+                    .toString();
+        }
+    }
+
+    /**
+     * Reads the next {@code name;temperature} line from the slice into {@code measurement}
+     * and parses the temperature backwards from the end of the line.
+     *
+     * @param measurement scratch buffer the line is copied into; the returned station refers to it
+     * @param memoryMappedSlice slice positioned at the start of a line
+     * @return the station (name still backed by {@code measurement}) and its temperature in tenths
+     */
+    private static StationMeasurement nextStationMeasurement(byte[] measurement, MappedByteBuffer memoryMappedSlice) {
+        byte b;
+        // No line is shorter than this, so the first bytes can be copied without looking for '\n'.
+        int i = rules.minMeasurementLength;
+        memoryMappedSlice.get(measurement, 0, i);
+        while ((b = memoryMappedSlice.get()) != '\n') {
+            measurement[i] = b;
+            i++;
+        }
+        var zeroOffset = '0';
+        int temperature = measurement[--i] - zeroOffset;
+        i--; // skipping dot
+        var base = 10;
+        while ((b = measurement[--i]) != ';') {
+            if (b == '-') {
+                temperature = -temperature;
+            }
+            else {
+                temperature = base * (b - zeroOffset) + temperature;
+                // Each further digit weighs ten times more; squaring the base broke on a third integer digit.
+                base *= 10;
+            }
+        }
+        return new StationMeasurement(new Station(measurement, i), temperature);
+    }
+
+    /**
+     * Memory-maps the slice read-only, rethrowing {@link IOException} unchecked for use in lambdas.
+     */
+    private static MappedByteBuffer map(MeasurementsSlice slice, FileChannel measurements) {
+        try {
+            return measurements.map(FileChannel.MapMode.READ_ONLY, slice.start, slice.length);
+        }
+        catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Slices the file into one chunk per available processor, adding chunks until every slice can be mapped.
+     */
+    private static List<MeasurementsSlice> slice(RandomAccessFile measurements) throws IOException {
+        int chunks = Runtime.getRuntime().availableProcessors();
+        List<MeasurementsSlice> measurementSlices;
+        while ((measurementSlices = slice(measurements, chunks)) == null) {
+            chunks++;
+        }
+        return measurementSlices;
+    }
+
+    /**
+     * Cuts the file into newline-aligned slices of roughly equal size.
+     *
+     * @return the slices, or {@code null} if a slice grew past {@link Integer#MAX_VALUE} bytes
+     *         (the single-buffer mapping limit) and the caller should retry with more chunks
+     */
+    private static List<MeasurementsSlice> slice(RandomAccessFile measurements, int chunks) throws IOException {
+        long measurementsFileLength = measurements.length();
+        long chunkLength = 0;
+        long remainder;
+        if (chunks < measurementsFileLength) {
+            chunks--;
+            do {
+                chunkLength = measurementsFileLength / ++chunks;
+                remainder = measurementsFileLength % chunkLength;
+            } while (chunkLength + remainder > Integer.MAX_VALUE);
+        }
+        if (chunkLength <= rules.maxMeasurementLength) {
+            // Chunks no longer than a single line: process the whole file as one slice.
+            return List.of(new MeasurementsSlice(0, measurementsFileLength));
+        }
+        var measurementSlices = new ArrayList<MeasurementsSlice>(chunks);
+        var sliceStart = 0L;
+        for (int i = 0; i < chunks - 1; i++) {
+            var sliceLength = chunkLength;
+            measurements.seek(sliceStart + sliceLength);
+            // readByte() advances the file pointer itself, no re-seek needed while scanning.
+            while (measurements.readByte() != '\n') {
+                sliceLength++;
+            }
+            sliceLength++;
+            if (sliceLength > Integer.MAX_VALUE) {
+                return null;
+            }
+            measurementSlices.add(new MeasurementsSlice(sliceStart, sliceLength));
+            sliceStart = sliceStart + sliceLength;
+        }
+        // sliceStart now marks the end of the last aligned slice (0 when chunks == 1),
+        // so the tail needs no lookup into a possibly empty list.
+        measurementSlices.add(new MeasurementsSlice(sliceStart, measurementsFileLength - sliceStart));
+        return measurementSlices;
+    }
+
+}