diff --git a/calculate_average_vaidhy.sh b/calculate_average_vaidhy.sh new file mode 100755 index 0000000..ca204f8 --- /dev/null +++ b/calculate_average_vaidhy.sh @@ -0,0 +1,19 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +JAVA_OPTS="--enable-preview" +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_vaidhy diff --git a/prepare_vaidhy.sh b/prepare_vaidhy.sh new file mode 100755 index 0000000..f83a3ff --- /dev/null +++ b/prepare_vaidhy.sh @@ -0,0 +1,19 @@ +#!/bin/bash +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +source "$HOME/.sdkman/bin/sdkman-init.sh" +sdk use java 21.0.1-graal 1>&2 diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java b/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java new file mode 100644 index 0000000..5795077 --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java @@ -0,0 +1,427 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import sun.misc.Unsafe; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.*; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.function.Function; +import java.util.function.Supplier; + +public class CalculateAverage_vaidhy { + + private static final class HashEntry { + private long startAddress; + private long endAddress; + private long suffix; + private int hash; + + IntSummaryStatistics value; + } + + private static class PrimitiveHashMap { + private final HashEntry[] entries; + private final int twoPow; + + PrimitiveHashMap(int twoPow) { + this.twoPow = twoPow; + this.entries = new HashEntry[1 << twoPow]; + for (int i = 0; i < entries.length; i++) { + this.entries[i] = new HashEntry(); + } + } + + public HashEntry find(long startAddress, long endAddress, long suffix, int hash) { + int len = entries.length; + int i = (hash ^ (hash >> twoPow)) & (len - 1); + + do { + HashEntry entry = entries[i]; + if (entry.value == null) { + return entry; + } + if (entry.hash == hash) { + long entryLength = entry.endAddress - entry.startAddress; + long lookupLength = endAddress - startAddress; + if ((entryLength == lookupLength) && (entry.suffix == suffix)) { + boolean found = compareEntryKeys(startAddress, endAddress, entry); + + if (found) { + return entry; + } + } + } + i++; + if (i == len) { + i = 0; + } + } while (i != hash); + return null; + } + + private static boolean compareEntryKeys(long startAddress, long endAddress, HashEntry entry) { + long entryIndex = entry.startAddress; + long lookupIndex = startAddress; + + for (; (lookupIndex + 7) < endAddress; lookupIndex += 8) { + if (UNSAFE.getLong(entryIndex) != UNSAFE.getLong(lookupIndex)) { + return false; + } + entryIndex += 8; + } + return true; + } + } + + private static final String FILE = "./measurements.txt"; + + private static Unsafe initUnsafe() { + try { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + return (Unsafe) theUnsafe.get(Unsafe.class); + } + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private static final Unsafe UNSAFE = initUnsafe(); + + private static int parseDouble(long startAddress, long endAddress) { + int normalized; + int length = (int) (endAddress - startAddress); + if (length == 5) { + normalized = (UNSAFE.getByte(startAddress + 1) ^ 0x30); + normalized = (normalized << 3) + (normalized << 1) + (UNSAFE.getByte(startAddress + 2) ^ 0x30); + normalized = (normalized << 3) + (normalized << 1) + (UNSAFE.getByte(startAddress + 4) ^ 0x30); + normalized = -normalized; + return normalized; + } + if (length == 3) { + normalized = (UNSAFE.getByte(startAddress) ^ 0x30); + normalized = (normalized << 3) + (normalized << 1) + (UNSAFE.getByte(startAddress + 2) ^ 0x30); + return normalized; + } + + if (UNSAFE.getByte(startAddress) == '-') { + normalized = (UNSAFE.getByte(startAddress + 1) ^ 0x30); + normalized = (normalized << 3) + (normalized << 1) + (UNSAFE.getByte(startAddress + 3) ^ 0x30); + normalized = -normalized; + return normalized; + } + else { + normalized = (UNSAFE.getByte(startAddress) ^ 0x30); + normalized = (normalized << 3) + (normalized << 1) + (UNSAFE.getByte(startAddress + 1) ^ 0x30); + normalized = (normalized << 3) + (normalized << 1) + (UNSAFE.getByte(startAddress + 3) ^ 0x30); + return normalized; + } + } + + interface MapReduce { + + void process(long keyStartAddress, long keyEndAddress, int hash, int temperature, long suffix); + + I result(); + } + + private final FileService fileService; + private final Supplier> chunkProcessCreator; + private final Function, T> reducer; + + interface FileService { + long length(); + + long address(); + } + + CalculateAverage_vaidhy(FileService fileService, + Supplier> mapReduce, + Function, T> reducer) { + this.fileService = fileService; + this.chunkProcessCreator = mapReduce; + this.reducer = reducer; + } + + static class LineStream { + private final long fileEnd; + private final long chunkEnd; + + private long position; + private int hash; + private long suffix; + byte[] b = new byte[4]; + + public LineStream(FileService fileService, long offset, long chunkSize) { + long fileStart = fileService.address(); + this.fileEnd = fileStart + fileService.length(); + this.chunkEnd = fileStart + offset + chunkSize; + this.position = fileStart + offset; + this.hash = 0; + } + + public boolean hasNext() { + return position <= chunkEnd && position < fileEnd; + } + + public long findSemi() { + int h = 0; + long s = 0; + long i = position; + while ((i + 3) < fileEnd) { + // Adding 16 as it is the offset for primitive arrays + ByteBuffer.wrap(b).putInt(UNSAFE.getInt(i)); + + if (b[3] == 0x3B) { + break; + } + i++; + h = ((h << 5) - h) ^ b[3]; + s = (s << 8) ^ b[3]; + + if (b[2] == 0x3B) { + break; + } + i++; + h = ((h << 5) - h) ^ b[2]; + s = (s << 8) ^ b[2]; + + if (b[1] == 0x3B) { + break; + } + i++; + h = ((h << 5) - h) ^ b[1]; + s = (s << 8) ^ b[1]; + + if (b[0] == 0x3B) { + break; + } + i++; + h = ((h << 5) - h) ^ b[0]; + s = (s << 8) ^ b[0]; + } + + this.hash = h; + this.suffix = s; + position = i + 1; + return i; + } + + public long skipLine() { + for (long i = position; i < fileEnd; i++) { + byte ch = UNSAFE.getByte(i); + if (ch == 0x0a) { + position = i + 1; + return i; + } + } + position = fileEnd; + return fileEnd; + } + + public long findTemperature() { + position += 3; + for (long i = position; i < fileEnd; i++) { + byte ch = UNSAFE.getByte(i); + if (ch == 0x0a) { + position = i + 1; + return i; + } + } + position = fileEnd; + return fileEnd; + } + } + + private void worker(long offset, long chunkSize, MapReduce lineConsumer) { + LineStream lineStream = new LineStream(fileService, offset, chunkSize); + + if (offset != 0) { + if (lineStream.hasNext()) { + // Skip the first line. + lineStream.skipLine(); + } + else { + // No lines then do nothing. + return; + } + } + while (lineStream.hasNext()) { + long keyStartAddress = lineStream.position; + long keyEndAddress = lineStream.findSemi(); + long keySuffix = lineStream.suffix; + int keyHash = lineStream.hash; + long valueStartAddress = lineStream.position; + long valueEndAddress = lineStream.findTemperature(); + int temperature = parseDouble(valueStartAddress, valueEndAddress); + lineConsumer.process(keyStartAddress, keyEndAddress, keyHash, temperature, keySuffix); + } + } + + public T master(long chunkSize, ExecutorService executor) { + long len = fileService.length(); + List> summaries = new ArrayList<>(); + + for (long offset = 0; offset < len; offset += chunkSize) { + long workerLength = Math.min(len, offset + chunkSize) - offset; + MapReduce mr = chunkProcessCreator.get(); + final long transferOffset = offset; + Future task = executor.submit(() -> { + worker(transferOffset, workerLength, mr); + return mr.result(); + }); + summaries.add(task); + } + List summariesDone = summaries.stream() + .map(task -> { + try { + return task.get(); + } + catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }) + .toList(); + return reducer.apply(summariesDone); + } + + static class DiskFileService implements FileService { + private final long fileSize; + private final long mappedAddress; + + DiskFileService(String fileName) throws IOException { + FileChannel fileChannel = FileChannel.open(Path.of(fileName), + StandardOpenOption.READ); + this.fileSize = fileChannel.size(); + this.mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, + fileSize, Arena.global()).address(); + } + + @Override + public long length() { + return fileSize; + } + + @Override + public long address() { + return mappedAddress; + } + } + + private static class ChunkProcessorImpl implements MapReduce { + + // 1 << 14 > 10,000 so it works + private final PrimitiveHashMap statistics = new PrimitiveHashMap(14); + + @Override + public void process(long keyStartAddress, long keyEndAddress, int hash, int temperature, long suffix) { + HashEntry entry = statistics.find(keyStartAddress, keyEndAddress, suffix, hash); + if (entry == null) { + throw new IllegalStateException("Hash table too small :("); + } + if (entry.value == null) { + entry.startAddress = keyStartAddress; + entry.endAddress = keyEndAddress; + entry.suffix = suffix; + entry.hash = hash; + entry.value = new IntSummaryStatistics(); + } + entry.value.accept(temperature); + } + + @Override + public PrimitiveHashMap result() { + return statistics; + } + } + + public static void main(String[] args) throws IOException { + DiskFileService diskFileService = new DiskFileService(FILE); + + CalculateAverage_vaidhy> calculateAverageVaidhy = new CalculateAverage_vaidhy<>( + diskFileService, + ChunkProcessorImpl::new, + CalculateAverage_vaidhy::combineOutputs); + + int proc = 2 * Runtime.getRuntime().availableProcessors(); + + long fileSize = diskFileService.length(); + long chunkSize = Math.ceilDiv(fileSize, proc); + + ExecutorService executor = Executors.newFixedThreadPool(proc); + Map output = calculateAverageVaidhy.master(chunkSize, executor); + executor.shutdown(); + + Map outputStr = toPrintMap(output); + System.out.println(outputStr); + } + + private static Map toPrintMap(Map output) { + + Map outputStr = new TreeMap<>(); + for (Map.Entry entry : output.entrySet()) { + IntSummaryStatistics stat = entry.getValue(); + outputStr.put(entry.getKey(), + STR."\{stat.getMin() / 10.0}/\{Math.round(stat.getAverage()) / 10.0}/\{stat.getMax() / 10.0}"); + } + return outputStr; + } + + private static Map combineOutputs( + List list) { + + Map output = new HashMap<>(10000); + for (PrimitiveHashMap map : list) { + for (HashEntry entry : map.entries) { + if (entry.value != null) { + String keyStr = unsafeToString(entry.startAddress, entry.endAddress); + + output.compute(keyStr, (ignore, val) -> { + if (val == null) { + return entry.value; + } + else { + val.combine(entry.value); + return val; + } + }); + } + } + } + + return output; + } + + private static String unsafeToString(long startAddress, long endAddress) { + byte[] keyBytes = new byte[(int) (endAddress - startAddress)]; + for (int i = 0; i < keyBytes.length; i++) { + keyBytes[i] = UNSAFE.getByte(startAddress + i); + } + return new String(keyBytes, StandardCharsets.UTF_8); + } +}