diff --git a/calculate_average_arjenvaneerde.sh b/calculate_average_arjenvaneerde.sh new file mode 100644 index 0000000..f409990 --- /dev/null +++ b/calculate_average_arjenvaneerde.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +JAVA_OPTS="" +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_arjenvaneerde diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_arjenvaneerde.java b/src/main/java/dev/morling/onebrc/CalculateAverage_arjenvaneerde.java new file mode 100644 index 0000000..e48e2de --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_arjenvaneerde.java @@ -0,0 +1,348 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +import java.io.File; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.concurrent.*; + +public class CalculateAverage_arjenvaneerde { + + static class Measure { + byte[] station; + int length; + int count; + int minTemp; + int maxTemp; + long sumTemp; + + Measure(final byte[] station, int count, int minTemp, int maxTemp, long sumTemp) { + this.station = station; + this.length = station.length; + this.count = count; + this.minTemp = minTemp; + this.maxTemp = maxTemp; + this.sumTemp = sumTemp; + } + + Measure(byte[] bytes, int startPos, int endPos, int temp) { + this.length = endPos - startPos; + this.station = new byte[this.length]; + System.arraycopy(bytes, startPos, this.station, 0, this.station.length); + this.count = 1; + this.minTemp = temp; + this.maxTemp = temp; + this.sumTemp = temp; + } + + void add(int temp) { + this.count++; + this.minTemp = Math.min(this.minTemp, temp); + this.maxTemp = Math.max(this.maxTemp, temp); + this.sumTemp += temp; + } + + public static Measure merge(Measure m1, Measure m2) { + return new Measure(m1.station, m1.count + m2.count, Math.min(m1.minTemp, m2.minTemp), Math.max(m1.maxTemp, m2.maxTemp), m1.sumTemp + m2.sumTemp); + } + + public String toString() { + return String.format("%.1f/%.1f/%.1f", this.minTemp / 10.0, this.sumTemp / (10.0 * this.count), this.maxTemp / 10.0); + } + } + + static class Measures { + static final int HASH_TABLE_SIZE = 8 * 1024; + + Measure[][] measureHashTable; + + Measures() { + this.measureHashTable = new Measure[HASH_TABLE_SIZE][20]; + } + + void add(byte[] bytes, int startPos, int endPos, int temp) { + int len = endPos - startPos; + int index = ((len - 2) & 0x0f | // 4 bits of the length + ((bytes[startPos + 0] & 0x1f) << 4) | // 5 bits of first char + ((bytes[startPos + 2] & 0x1f) << 9) // 5 bits of third char + // ((bytes[startPos + 1] & 0x1f) << 14) // 5 bits of second char + ) & (HASH_TABLE_SIZE - 1); + Measure[] arr = this.measureHashTable[index]; + int i = 0; + boolean found = false; + int arrLength = arr.length; + while (i < arrLength) { + Measure m = arr[i]; + if (m == null) { + // Not found. Add new entry. + arr[i] = new Measure(bytes, startPos, endPos, temp); + return; + } + if (m.length == len) { + switch (len) { + case 0: + break; + case 1: + if (m.station[0] == bytes[startPos]) { + found = true; + } + break; + case 2: + if (m.station[1] == bytes[startPos + 1] && + m.station[0] == bytes[startPos]) { + found = true; + } + break; + case 3: + if (m.station[2] == bytes[startPos + 2] && + m.station[1] == bytes[startPos + 1] && + m.station[0] == bytes[startPos]) { + found = true; + } + break; + case 4: + if (m.station[3] == bytes[startPos + 3] && + m.station[2] == bytes[startPos + 2] && + m.station[1] == bytes[startPos + 1] && + m.station[0] == bytes[startPos]) { + found = true; + } + break; + case 5: + if (m.station[4] == bytes[startPos + 4] && + m.station[3] == bytes[startPos + 3] && + m.station[2] == bytes[startPos + 2] && + m.station[1] == bytes[startPos + 1] && + m.station[0] == bytes[startPos]) { + found = true; + } + break; + case 6: + if (m.station[5] == bytes[startPos + 5] && + m.station[4] == bytes[startPos + 4] && + m.station[3] == bytes[startPos + 3] && + m.station[2] == bytes[startPos + 2] && + m.station[1] == bytes[startPos + 1] && + m.station[0] == bytes[startPos]) { + found = true; + } + break; + case 7: + if (m.station[6] == bytes[startPos + 6] && + m.station[5] == bytes[startPos + 5] && + m.station[4] == bytes[startPos + 4] && + m.station[3] == bytes[startPos + 3] && + m.station[2] == bytes[startPos + 2] && + m.station[1] == bytes[startPos + 1] && + m.station[0] == bytes[startPos]) { + found = true; + } + break; + default: + found = Arrays.mismatch(m.station, 0, len, bytes, startPos, endPos) == -1; + break; + } + if (found) { + // Add info. + m.add(temp); + return; + } + } + i++; + } + // throw new RuntimeException("Reached end of Measures array."); + Measure[] newArr = new Measure[arr.length * 2]; + System.arraycopy(arr, 0, newArr, 0, arr.length); + newArr[i] = new Measure(bytes, startPos, endPos, temp); + this.measureHashTable[index] = newArr; + } + } + + private static class BytesProcessor implements Runnable { + final Measures measures; + final byte[] buffer; + long absoluteStartPos; + int startPos; + int endPos; + + BytesProcessor(byte[] buffer, long absoluteStartPos, int startPos, int endPos, final Measures measures) { + this.buffer = buffer; + this.absoluteStartPos = absoluteStartPos; + this.startPos = startPos; + this.endPos = endPos; + this.measures = measures; + } + + @Override + public void run() { + int startOfLinePos = startPos; + int endOfLinePos = startOfLinePos; + int sepPos; + // Process all lines + while (endOfLinePos < endPos) { + while (buffer[endOfLinePos] != ';') { + endOfLinePos++; + } + sepPos = endOfLinePos; + int temperature = 0; + byte lineByte; + endOfLinePos++; + lineByte = buffer[endOfLinePos]; + while (lineByte != 0x0A) { + if (lineByte >= '0' && lineByte <= '9') { + temperature = temperature * 10 + (lineByte - '0'); + } + lineByte = buffer[++endOfLinePos]; + } + if (buffer[sepPos + 1] == '-') { + temperature = -temperature; + } + measures.add(buffer, startOfLinePos, sepPos, temperature); + startOfLinePos = ++endOfLinePos; + } + } + } + + private static class MeasuresMergeProcessor implements Runnable { + final Measures[] measures; + final int startIndex; + final int endIndex; + final ConcurrentSkipListMap results; + + MeasuresMergeProcessor(final Measures[] measures, + final int startIndex, + final int endIndex, + final ConcurrentSkipListMap results) { + this.measures = measures; + this.startIndex = startIndex; + this.endIndex = endIndex; + this.results = results; + } + + @Override + public void run() { + for (int mIdx = 0; mIdx < this.measures.length; mIdx++) { + for (int hashIdx = this.startIndex; hashIdx < this.endIndex; hashIdx++) { + Measure[] mArr = this.measures[mIdx].measureHashTable[hashIdx]; + int i = 0; + Measure measure = mArr[i]; + while (measure != null) { + Measure finalMeasure = measure; + this.results.compute(new String(measure.station, StandardCharsets.UTF_8), (k, v) -> v == null ? finalMeasure : Measure.merge(v, finalMeasure)); + measure = mArr[++i]; + } + } + } + } + } + + private static final String FILE = "./measurements.txt"; + // private static final String FILE = "./src/test/resources/samples/measurements-1.txt"; + // private static final String FILE = "./src/test/resources/samples/measurements-10000-unique-keys.txt"; + private static final int NUM_THREADS = 8; // Runtime.getRuntime().availableProcessors(); + private static final int BYTE_BUFFER_SIZE = 16 * 1024 * 1024; + private static final ExecutorService threads = Executors.newFixedThreadPool(NUM_THREADS); + private static final List> futures = new ArrayList<>(NUM_THREADS); + private static final Measures[] measures = new Measures[NUM_THREADS]; + + public static void main(String[] args) throws Exception { + long startTime = System.currentTimeMillis(); + for (int i = 0; i < NUM_THREADS; i++) { + measures[i] = new Measures(); + } + File file = new File(FILE); + try (RandomAccessFile raFile = new RandomAccessFile(file, "r"); + FileChannel inChannel = raFile.getChannel()) { + + ByteBuffer[] byteBuffers = new ByteBuffer[2]; + byteBuffers[0] = ByteBuffer.allocate(BYTE_BUFFER_SIZE); + byteBuffers[1] = ByteBuffer.allocate(BYTE_BUFFER_SIZE); + int currrentByteBuffer = 0; + long numBytesRead = 0; + long numTotBytesRead = 0; + numBytesRead = inChannel.read(byteBuffers[currrentByteBuffer]); + byteBuffers[currrentByteBuffer].flip(); + while (numBytesRead > 0 || byteBuffers[currrentByteBuffer].position() > 0) { + if (numBytesRead > 0) { + numTotBytesRead += numBytesRead; + } + // System.out.println((double) numTotBytesRead/(1024*1024)); + // Distribute buffer chunks across threads. + int bufferLimit = byteBuffers[currrentByteBuffer].limit(); + byte[] fileBytes = byteBuffers[currrentByteBuffer].array(); + int chunckSize = bufferLimit / NUM_THREADS; + int startOfChunkPos = 0; + int endOfChunkPos = 0; + for (int chunk = 0; chunk < NUM_THREADS; chunk++) { + if (chunk == NUM_THREADS - 1) { + endOfChunkPos = bufferLimit - 1; + } + else { + endOfChunkPos = Math.min((chunk + 1) * chunckSize, bufferLimit - 1); + } + while (endOfChunkPos > startOfChunkPos && + fileBytes[endOfChunkPos] != 0x0A) { + endOfChunkPos--; + } + if (endOfChunkPos > startOfChunkPos) { + endOfChunkPos++; + Future f = threads.submit(new BytesProcessor(fileBytes, numTotBytesRead - numBytesRead, startOfChunkPos, endOfChunkPos, measures[chunk]), + chunk); + futures.add(f); + startOfChunkPos = endOfChunkPos; + } + } + int nextByteBuffer = (currrentByteBuffer + 1) % 2; + byteBuffers[nextByteBuffer].clear(); + if (numBytesRead > 0) { + // Copy over remaining bytes to next buffer. + for (; endOfChunkPos < bufferLimit; endOfChunkPos++) { + byteBuffers[nextByteBuffer].put(fileBytes[endOfChunkPos]); + } + // Read next set from channel. + numBytesRead = inChannel.read(byteBuffers[nextByteBuffer]); + byteBuffers[nextByteBuffer].flip(); + // Wait for all threads to finish. + for (Future future : futures) { + future.get(); + } + futures.clear(); + } + currrentByteBuffer = nextByteBuffer; + } + } + ConcurrentSkipListMap measurements = new ConcurrentSkipListMap<>(); + int chunkSize = Measures.HASH_TABLE_SIZE / measures.length; + for (int i = 0; i < measures.length; i++) { + Future f = threads.submit(new MeasuresMergeProcessor(measures, i * chunkSize, (i + 1) * chunkSize, measurements), i); + futures.add(f); + } + for (Future future : futures) { + future.get(); + } + futures.clear(); + threads.shutdown(); + threads.awaitTermination(1, TimeUnit.MILLISECONDS); + System.out.println(measurements); + // long endTime = System.currentTimeMillis(); + // System.out.printf("Duration : %.3f%n", (endTime - startTime) / 1000.0); + } + +}