From aa0395d01bd800d82dc83b39df27c688e1f03363 Mon Sep 17 00:00:00 2001
From: Andrew Sun
Date: Sun, 7 Jan 2024 04:00:58 -0500
Subject: [PATCH] Add entry by Andrew Sun

---
 calculate_average_asun.sh                     |  23 ++
 .../morling/onebrc/CalculateAverage_asun.java | 375 ++++++++++++++++++
 2 files changed, 398 insertions(+)
 create mode 100755 calculate_average_asun.sh
 create mode 100644 src/main/java/dev/morling/onebrc/CalculateAverage_asun.java

diff --git a/calculate_average_asun.sh b/calculate_average_asun.sh
new file mode 100755
index 0000000..8ed6f7e
--- /dev/null
+++ b/calculate_average_asun.sh
@@ -0,0 +1,23 @@
+#!/bin/sh
+#
+#  Copyright 2023 The original authors
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
+# Uncomment below to use sdk
+# source "$HOME/.sdkman/bin/sdkman-init.sh"
+# sdk use java 21.0.1-graal 1>&2
+
+JAVA_OPTS="--enable-preview --add-modules jdk.incubator.vector -XX:+UnlockExperimentalVMOptions -Xms500m -Xmx500m -XX:CompilationMode=high-only"
+time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_asun
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java
new file mode 100644
index 0000000..88a90ea
--- /dev/null
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java
@@ -0,0 +1,375 @@
+/*
+ *  Copyright 2023 The original authors
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ */
+package dev.morling.onebrc;
+
+import jdk.incubator.vector.*;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.RandomAccessFile;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.TreeMap;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+
+// based on spullara's submission
+
+/**
+ * Computes the minimum, mean and maximum temperature per station from a file of {@code name;temperature} lines.
+ * The file is cut into newline-aligned segments that are memory-mapped and parsed in parallel, then merged by name.
+ */
+public class CalculateAverage_asun {
+    /** Input path used when no command-line argument is given. */
+    private static final String FILE = "./measurements.txt";
+
+    private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_256;
+    private static final VectorSpecies<Integer> INT_SPECIES = IntVector.SPECIES_256;
+    /** Number of bytes loaded per {@link #BYTE_SPECIES} vector. */
+    private static final int VECTOR_SIZE = 32;
+
+    /** Holds the lane indices 0..31, so {@code ASC.lt(n)} selects the first {@code n} byte lanes. */
+    private static final ByteVector ASC;
+    static {
+        byte[] bytes = new byte[VECTOR_SIZE];
+        for (int i = 0; i < VECTOR_SIZE; i++) {
+            bytes[i] = (byte) i;
+        }
+
+        ASC = ByteVector.fromArray(BYTE_SPECIES, bytes, 0);
+    }
+
+    /**
+     * Aggregates the measurements file and prints the per-station statistics, sorted by name, to standard output.
+     *
+     * @param args optional; {@code args[0]} replaces the default input path
+     * @throws IOException if the file cannot be read while segment boundaries are computed
+     */
+    public static void main(String[] args) throws IOException, ExecutionException, InterruptedException {
+        var filename = args.length == 0 ? FILE : args[0];
+        var file = new File(filename);
+
+        List<FileSegment> fileSegments = getFileSegments(file);
+        var resultsMap = fileSegments.stream()
+                .map(segment -> processSegment(filename, segment))
+                .parallel()
+                .flatMap(partition -> partition.getAll().stream())
+                // Decode names explicitly as UTF-8 instead of relying on the platform default charset
+                .collect(Collectors.toMap(e -> new String(e.key(), StandardCharsets.UTF_8), Entry::value, CalculateAverage_asun::merge, TreeMap::new));
+
+        System.out.println(resultsMap);
+
+        // Stop the JVM right away rather than running the normal shutdown sequence
+        Runtime.getRuntime().halt(0);
+    }
+
+    /**
+     * Maps one segment of the file into memory and aggregates every line in it.
+     * Lines are scanned one 32-byte vector at a time while a full vector still fits; the tail is parsed byte by byte.
+     *
+     * @param filename path of the measurements file
+     * @param segment newline-aligned byte range of the file to process
+     * @return the per-station aggregates found in this segment
+     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be opened or mapped
+     */
+    private static ByteArrayToResultMap processSegment(String filename, FileSegment segment) {
+        var resultMap = new ByteArrayToResultMap();
+        try (var fileChannel = (FileChannel) Files.newByteChannel(Path.of(filename), StandardOpenOption.READ)) {
+            var bb = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segment.end() - segment.start());
+            var ms = MemorySegment.ofBuffer(bb);
+
+            // Up to 100 characters for a city name
+            var buffer = new byte[100 + VECTOR_SIZE];
+            long startLine;
+            long pos = 0;
+            long limit = ms.byteSize();
+            long vectorLimit = limit - VECTOR_SIZE;
+
+            vector: while ((startLine = pos) < vectorLimit) {
+                long currentPosition = startLine;
+                ByteVector r;
+                VectorMask<Byte> m;
+                int offset = 0;
+
+                IntVector h = IntVector.zero(INT_SPECIES);
+                while (true) {
+                    if (currentPosition >= vectorLimit) {
+                        break vector;
+                    }
+
+                    r = ByteVector.fromMemorySegment(BYTE_SPECIES, ms, currentPosition, ByteOrder.LITTLE_ENDIAN);
+                    r.intoArray(buffer, offset);
+                    offset += VECTOR_SIZE;
+                    m = r.eq((byte) ';');
+                    if (m.anyTrue()) {
+                        int firstTrue = m.firstTrue();
+                        currentPosition += firstTrue;
+                        // note: target platform likely does not have AVX-512, so manipulating and using m directly is likely to be slow
+                        ByteVector lastMask = (ByteVector) ASC.lt((byte) firstTrue).toVector();
+                        h = h.mul(31);
+                        h = h.add(r.and(lastMask).reinterpretAsInts());
+                        break;
+                    }
+                    else {
+                        currentPosition += VECTOR_SIZE;
+                        h = h.mul(31);
+                        h = h.add(r.reinterpretAsInts());
+                    }
+                }
+
+                int hash = h.reduceLanes(VectorOperators.ADD);
+
+                // currentPosition now has index of semicolon
+                int nameLen = (int) (currentPosition - startLine);
+                currentPosition++;
+
+                // The 8-byte read below must stay inside the mapping; the scalar loop handles the last lines
+                if (currentPosition >= limit - 8) {
+                    break;
+                }
+
+                // NOTE(review): the code below treats the lowest byte of g as the first character - confirm byte order on the target
+                long g = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, currentPosition);
+                int negative = (g & 0xff) == '-' ? -1 : 1;
+
+                // 00101101 MINUS
+                // 00101110 PERIOD
+                // 00001101 CR
+                // 00001010 LF
+                // 00110000 0
+                // 00111001 9
+
+                // scan for LF
+                long lf = ~g & 0x20202020202020L;
+                int tzc = Long.numberOfTrailingZeros(lf);
+                long bytesToLF = tzc / 8;
+
+                int shift = 64 - tzc & 0b111000;
+
+                long reversedDigits = Long.reverseBytes(g) >> shift;
+                long digitBits = reversedDigits & (0x1010101010101010L >> shift);
+                long digitsExt = (digitBits >> 1 | digitBits >> 2 | digitBits >> 3 | digitBits >> 4);
+
+                long digitsOnly = Long.compress(reversedDigits, digitsExt);
+
+                long temp = (digitsOnly & 0xf)
+                        + 10 * ((digitsOnly >> 4) & 0xf)
+                        + 100 * ((digitsOnly >> 8) & 0xf);
+
+                temp *= negative;
+
+                currentPosition += bytesToLF + 1;
+
+                resultMap.putOrMerge(buffer, 0, nameLen, temp, hash);
+                pos = currentPosition;
+            }
+
+            while ((startLine = pos) < limit) {
+                long currentPosition = startLine;
+                byte b;
+                int offset = 0;
+
+                // Positions are relative to the mapping, so the bound is limit rather than the absolute segment end
+                while (currentPosition < limit && (b = ms.get(ValueLayout.JAVA_BYTE, currentPosition++)) != ';') {
+                    buffer[offset++] = b;
+                }
+                // Invariant: the remaining length is less than VECTOR_SIZE, so we can just run the last round of hashing
+                int hash = ByteVector.fromArray(BYTE_SPECIES, buffer, 0, ASC.lt((byte) offset))
+                        .reinterpretAsInts()
+                        .reduceLanes(VectorOperators.ADD);
+
+                int temp;
+                int negative = 1;
+                // Inspired by @yemreinci to unroll this even further
+                if (ms.get(ValueLayout.JAVA_BYTE, currentPosition) == '-') {
+                    negative = -1;
+                    currentPosition++;
+                }
+                if (ms.get(ValueLayout.JAVA_BYTE, currentPosition + 1) == '.') {
+                    temp = negative * ((ms.get(ValueLayout.JAVA_BYTE, currentPosition) - '0') * 10 + (ms.get(ValueLayout.JAVA_BYTE, currentPosition + 2) - '0'));
+                    currentPosition += 3;
+                }
+                else {
+                    temp = negative * ((ms.get(ValueLayout.JAVA_BYTE, currentPosition) - '0') * 100
+                            + ((ms.get(ValueLayout.JAVA_BYTE, currentPosition + 1) - '0') * 10 + (ms.get(ValueLayout.JAVA_BYTE, currentPosition + 3) - '0')));
+                    currentPosition += 4;
+                }
+                if (ms.get(ValueLayout.JAVA_BYTE, currentPosition) == '\r') {
+                    currentPosition++;
+                }
+                currentPosition++;
+                resultMap.putOrMerge(buffer, 0, offset, temp, hash);
+                pos = currentPosition;
+            }
+            return resultMap;
+        }
+        catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Splits the file into byte ranges of roughly equal size whose boundaries fall just after a newline.
+     * When the ranges would be smaller than 1,000,000 bytes the whole file is returned as a single segment.
+     *
+     * @param file the measurements file
+     * @return the segments in file order
+     * @throws IOException if the file cannot be opened or read
+     */
+    private static List<FileSegment> getFileSegments(File file) throws IOException {
+        int numberOfSegments = Runtime.getRuntime().availableProcessors() * 8;
+        long fileSize = file.length();
+        long segmentSize = fileSize / numberOfSegments;
+        List<FileSegment> segments = new ArrayList<>(numberOfSegments);
+        // Pointless to split small files
+        if (segmentSize < 1_000_000) {
+            segments.add(new FileSegment(0, fileSize));
+            return segments;
+        }
+        try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) {
+            for (int i = 0; i < numberOfSegments; i++) {
+                long segStart = i * segmentSize;
+                long segEnd = (i == numberOfSegments - 1) ? fileSize : segStart + segmentSize;
+                segStart = findSegment(i, 0, randomAccessFile, segStart, segEnd);
+                segEnd = findSegment(i, numberOfSegments - 1, randomAccessFile, segEnd, fileSize);
+
+                segments.add(new FileSegment(segStart, segEnd));
+            }
+        }
+        return segments;
+    }
+
+    /** Folds the aggregate {@code value} into {@code v} and returns {@code v}. */
+    private static Result merge(Result v, Result value) {
+        return merge(v, value.min, value.max, value.sum, value.count);
+    }
+
+    /** Folds a minimum, maximum, sum and count (in that parameter order) into {@code v} and returns {@code v}. */
+    private static Result merge(Result v, long value, long value1, long value2, long value3) {
+        v.min = Math.min(v.min, value);
+        v.max = Math.max(v.max, value1);
+        v.sum += value2;
+        v.count += value3;
+        return v;
+    }
+
+    /**
+     * Moves {@code location} forward to just past the next newline, reading no further than {@code fileSize}.
+     * The location is returned unchanged when {@code i} equals {@code skipSegment}.
+     */
+    private static long findSegment(int i, int skipSegment, RandomAccessFile raf, long location, long fileSize) throws IOException {
+        if (i != skipSegment) {
+            raf.seek(location);
+            while (location < fileSize) {
+                location++;
+                if (raf.read() == '\n') {
+                    break;
+                }
+            }
+        }
+        return location;
+    }
+
+    /** Running minimum, maximum, sum and count of temperatures, all held in tenths of a degree. */
+    static class Result {
+        long min, max, sum;
+        long count;
+
+        Result(long value) {
+            min = max = sum = value;
+            this.count = 1;
+        }
+
+        /** Formats the aggregate as {@code min/mean/max} with one decimal place each. */
+        @Override
+        public String toString() {
+            return round(min) + "/" + round((sum / 10.0) / count) + "/" + round(max);
+        }
+
+        double round(double v) {
+            return Math.round(v * 10.0) / 10.0;
+        }
+
+        double round(long v) {
+            return v / 10.0;
+        }
+
+    }
+
+    /** A station name as raw bytes together with its aggregate. */
+    record Entry(byte[] key, Result value) {
+    }
+
+    /** Half-open byte range {@code [start, end)} of the input file. */
+    record FileSegment(long start, long end) {
+    }
+
+    /** Fixed-size open-addressing hash map from station-name bytes to {@link Result}; callers supply the hash. */
+    static class ByteArrayToResultMap {
+        public static final int MAPSIZE = 1024 * 128;
+        Result[] slots = new Result[MAPSIZE];
+        byte[][] keys = new byte[MAPSIZE][];
+
+        /**
+         * Adds {@code temp} to the entry for the name in {@code key[offset, offset + size)}, creating it if absent.
+         *
+         * @param hash caller-computed hash of the name; its low bits pick the first slot probed
+         */
+        public void putOrMerge(byte[] key, int offset, int size, long temp, int hash) {
+            int slot = hash & (slots.length - 1);
+            var slotValue = slots[slot];
+            // Linear probe for open slot
+            while (slotValue != null && (keys[slot].length != size || !Arrays.equals(keys[slot], 0, size, key, offset, offset + size))) {
+                slot = (slot + 1) & (slots.length - 1);
+                slotValue = slots[slot];
+            }
+            Result value = slotValue;
+            if (value == null) {
+                slots[slot] = new Result(temp);
+                byte[] bytes = new byte[size];
+                System.arraycopy(key, offset, bytes, 0, size);
+                keys[slot] = bytes;
+            }
+            else {
+                value.min = Math.min(value.min, temp);
+                value.max = Math.max(value.max, temp);
+                value.sum += temp;
+                value.count += 1;
+            }
+        }
+
+        // Get all pairs
+        public List<Entry> getAll() {
+            List<Entry> result = new ArrayList<>(slots.length);
+            for (int i = 0; i < slots.length; i++) {
+                Result slotValue = slots[i];
+                if (slotValue != null) {
+                    result.add(new Entry(keys[i], slotValue));
+                }
+            }
+            return result;
+        }
+    }
+}