diff --git a/calculate_average_jparera.sh b/calculate_average_jparera.sh new file mode 100755 index 0000000..4c7a9e7 --- /dev/null +++ b/calculate_average_jparera.sh @@ -0,0 +1,19 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector -XX:-TieredCompilation" +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_jparera diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java new file mode 100644 index 0000000..1325255 --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java @@ -0,0 +1,351 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.TreeMap; +import java.util.function.Function; +import java.util.stream.Collectors; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.VectorSpecies; + +public class CalculateAverage_jparera { + private static final String FILE = "./measurements.txt"; + + private static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_PREFERRED; + + private static final int BYTE_SPECIES_SIZE = BYTE_SPECIES.vectorByteSize(); + + private static final int BYTE_SPECIES_LANES = BYTE_SPECIES.length(); + + private static final ValueLayout.OfLong LONG_U_LE = ValueLayout.JAVA_LONG_UNALIGNED + .withOrder(ByteOrder.LITTLE_ENDIAN); + + public static void main(String[] args) throws IOException { + try (var fc = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { + try (var arena = Arena.ofShared()) { + var fs = fc.map(MapMode.READ_ONLY, 0, fc.size(), arena); + var map = chunks(fs) + .parallelStream() + .map(Chunk::parse) + .flatMap(List::stream) + .collect(Collectors.toMap( + Entry::key, + Function.identity(), + Entry::merge, + TreeMap::new)); + System.out.println(map); + } + } + } + + private static Collection chunks(MemorySegment ms) { + var cpus = Runtime.getRuntime().availableProcessors(); + long expectedChunkSize = Math.ceilDiv(ms.byteSize(), cpus); + var chunks = new ArrayList(); + long fileSize = ms.byteSize(); + long offset = 0; + while (offset < fileSize) { + var end = Math.min(offset + expectedChunkSize, fileSize); + while (end < fileSize && ms.get(ValueLayout.JAVA_BYTE, end++) != '\n') { + } + long len = end - offset; + chunks.add(new Chunk(ms.asSlice(offset, len))); + offset = end; + } + return chunks; + } + + private static final class Chunk { + private static final byte SEPARATOR = ';'; + + private static final byte DECIMAL_SEPARATOR = '.'; + + private static final byte LF = '\n'; + + private static final byte MINUS = '-'; + + private static final int KEY_LOG2_BYTES = 7; + + private static final int KEY_BYTES = 1 << KEY_LOG2_BYTES; + + private static final int MAP_CAPACITY = 1 << 16; + + private static final int BUCKET_MASK = MAP_CAPACITY - 1; + + private final MemorySegment segment; + + private final Entry[] entries = new Entry[MAP_CAPACITY]; + + private long offset; + + private byte current; + + private boolean hasCurrent = true; + + Chunk(MemorySegment segment) { + this.segment = segment; + } + + public List parse() { + long size = this.segment.byteSize(); + long safe = size - KEY_BYTES; + while (offset < safe) { + var e = vectorizedEntry(); + int value = vectorizedValue(); + e.add(value); + } + next(); + while (hasCurrent()) { + var e = entry(); + int value = value(); + e.add(value); + } + var output = new ArrayList(entries.length); + for (int i = 0; i < entries.length; i++) { + var e = entries[i]; + if (e != null) { + output.add(e); + } + } + return output; + } + + private Entry vectorizedEntry() { + var start = this.offset; + var first = ByteVector.fromMemorySegment(BYTE_SPECIES, this.segment, start, ByteOrder.nativeOrder()); + int equals = first.eq(SEPARATOR).firstTrue(); + int len = equals; + for (int i = BYTE_SPECIES_SIZE; equals == BYTE_SPECIES_LANES; i += BYTE_SPECIES_SIZE) { + var next = ByteVector.fromMemorySegment(BYTE_SPECIES, this.segment, start + i, ByteOrder.nativeOrder()); + equals = next.eq(SEPARATOR).firstTrue(); + len += equals; + } + this.offset = start + len + 1; + int index = hash(this.segment, start, len); + int count = 0; + while (count < BUCKET_MASK) { + index = index & BUCKET_MASK; + var e = this.entries[index]; + if (e == null) { + return this.entries[index] = new Entry(len, this.segment.asSlice(start, KEY_BYTES)); + } + else if (e.keyLength() == len && vectorizedEquals(e, first, start, len)) { + return e; + } + index++; + count++; + } + throw new IllegalStateException("Map is full!"); + } + + private Entry entry() { + long start = this.offset - 1; + int len = 0; + while (hasCurrent() && current != SEPARATOR) { + len++; + next(); + } + expect(SEPARATOR); + int index = hash(segment, start, len); + int count = 0; + while (count < BUCKET_MASK) { + index = index & BUCKET_MASK; + var e = this.entries[index]; + if (e == null) { + return this.entries[index] = new Entry(len, this.segment.asSlice(start, len)); + } + else if (e.keyLength() == len && equals(e, start, len)) { + return e; + } + index++; + count++; + } + throw new IllegalStateException("Map is full!"); + } + + private static final long MULTIPLY_ADD_DIGITS = 100 * (1L << 24) + 10 * (1L << 16) + 1; + + private int vectorizedValue() { + long dw = this.segment.get(LONG_U_LE, this.offset); + boolean negative = ((dw & 0xFF) ^ MINUS) == 0; + int zeros = Long.numberOfTrailingZeros(~dw & 0x10101000L); + dw = ((negative ? (dw & ~0xFF) : dw) << (28 - zeros)) & 0x0F000F0F00L; + int value = (int) (((dw * MULTIPLY_ADD_DIGITS) >>> 32) & 0x3FF); + this.offset += (zeros >>> 3) + 3; + return negative ? -value : value; + } + + private int value() { + int value = 0; + var negative = false; + if (consume(MINUS)) { + negative = true; + } + while (hasCurrent()) { + if ((current & 0xF0) == 0x30) { + value *= 10; + value += current - '0'; + } + else if (current != DECIMAL_SEPARATOR) { + break; + } + next(); + } + if (hasCurrent()) { + expect(LF); + } + return negative ? -value : value; + } + + private boolean vectorizedEquals(Entry entry, ByteVector okey, long offset, int len) { + var ekey = ByteVector.fromMemorySegment(BYTE_SPECIES, entry.segment(), 0, ByteOrder.nativeOrder()); + int equals = ekey.eq(okey).not().firstTrue(); + if (equals != BYTE_SPECIES_LANES) { + return equals >= len; + } + long eo = BYTE_SPECIES_SIZE; + int total = BYTE_SPECIES_LANES; + while (equals == BYTE_SPECIES_LANES & eo < KEY_BYTES) { + offset += BYTE_SPECIES_SIZE; + ekey = ByteVector.fromMemorySegment(BYTE_SPECIES, entry.segment(), eo, ByteOrder.nativeOrder()); + okey = ByteVector.fromMemorySegment(BYTE_SPECIES, segment, offset, ByteOrder.nativeOrder()); + equals = ekey.eq(okey).not().firstTrue(); + total += equals; + eo += BYTE_SPECIES_SIZE; + } + return total >= len; + } + + private boolean equals(Entry entry, long offset, int len) { + return MemorySegment.mismatch(this.segment, offset, offset + len, entry.segment(), 0, len) == -1; + } + + private static final int GOLDEN_RATIO = 0x9E3779B9; + private static final int HASH_LROTATE = 5; + + private static int hash(MemorySegment ms, long start, int len) { + int x, y; + if (len >= Integer.BYTES) { + x = ms.get(ValueLayout.JAVA_INT_UNALIGNED, start); + y = ms.get(ValueLayout.JAVA_INT_UNALIGNED, start + len - Integer.BYTES); + } + else { + x = ms.get(ValueLayout.JAVA_BYTE, start); + y = ms.get(ValueLayout.JAVA_BYTE, start + len - Byte.BYTES); + } + return (Integer.rotateLeft(x * GOLDEN_RATIO, HASH_LROTATE) ^ y) * GOLDEN_RATIO; + } + + private void expect(byte b) { + if (!consume(b)) { + throw new IllegalStateException("Unexpected token!"); + } + } + + private boolean consume(byte b) { + if (current == b) { + next(); + return true; + } + return false; + } + + private boolean hasCurrent() { + return hasCurrent; + } + + private void next() { + if (offset < segment.byteSize()) { + this.current = segment.get(ValueLayout.JAVA_BYTE, offset++); + } + else { + this.hasCurrent = false; + } + } + } + + private static final class Entry { + private final int keyLength; + + private final MemorySegment segment; + + private int min = Integer.MAX_VALUE; + + private int max = Integer.MIN_VALUE; + + private long sum; + + private int count; + + Entry(int keyLength, MemorySegment segment) { + this.keyLength = keyLength; + this.segment = segment; + } + + int keyLength() { + return keyLength; + } + + MemorySegment segment() { + return segment; + } + + public String key() { + return new String(segment.asSlice(0, keyLength).toArray(ValueLayout.JAVA_BYTE), StandardCharsets.UTF_8); + } + + public void add(int value) { + min = Math.min(min, value); + max = Math.max(max, value); + sum += value; + count++; + } + + public Entry merge(Entry o) { + min = Math.min(min, o.min); + max = Math.max(max, o.max); + sum += o.sum; + count += o.count; + return this; + } + + @Override + public String toString() { + var average = Math.round(((sum / 10.0) / count) * 10.0); + return decimal(min) + "/" + decimal(average) + "/" + decimal(max); + } + + private static String decimal(long value) { + boolean negative = value < 0; + value = Math.abs(value); + return (negative ? "-" : "") + (value / 10) + "." + (value % 10); + } + } +}