From f4a0039a591fc7c02306af5fc7a8fbca8a292668 Mon Sep 17 00:00:00 2001 From: tivrfoa Date: Mon, 29 Jan 2024 17:24:04 -0300 Subject: [PATCH] Try more chunks than threads, and of different sizes (#644) /** * Solution based on thomaswue solution, commit: * commit d0a28599c293d3afe3291fc3cf169a7b25ae9ae6 * Author: Thomas Wuerthinger * Date: Sun Jan 21 20:13:48 2024 +0100 * * The goal here was to try to improve the runtime of his 10k * solution of: 00:04.516 * * With Thomas latest changes, his time is probably much better * already, and maybe even 1st place for the 10k too. * See: https://github.com/gunnarmorling/1brc/pull/606 * * But as I was already coding something, I'll submit just to * see if it will be faster than his *previous* 10k time of * 00:04.516 * * Changes: * It's a similar idea of my previous solution, that if you split * the chunks evenly, some threads might finish much faster and * stay idle, so: * 1) Create more chunks than threads, so the ones that finish first * can do something; * 2) Decrease chunk sizes as we get closer to the end of the file. */ --- prepare_tivrfoa.sh | 2 +- .../onebrc/CalculateAverage_tivrfoa.java | 364 ++++++++++-------- 2 files changed, 197 insertions(+), 169 deletions(-) diff --git a/prepare_tivrfoa.sh b/prepare_tivrfoa.sh index 7cbf309..024d6f9 100755 --- a/prepare_tivrfoa.sh +++ b/prepare_tivrfoa.sh @@ -20,7 +20,7 @@ sdk use java 21.0.2-graal 1>&2 # ./mvnw clean verify removes target/ and will re-trigger native image creation. if [ ! -f target/CalculateAverage_tivrfoa_image ]; then - NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_tivrfoa\$Scanner" + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -H:-GenLoopSafepoints -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_tivrfoa\$Scanner" # Use -H:MethodFilter=CalculateAverage_tivrfoa.* -H:Dump=:2 -H:PrintGraph=Network for IdealGraphVisualizer graph dumping. native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_tivrfoa_image dev.morling.onebrc.CalculateAverage_tivrfoa fi diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_tivrfoa.java b/src/main/java/dev/morling/onebrc/CalculateAverage_tivrfoa.java index a1b4844..54f13cb 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_tivrfoa.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_tivrfoa.java @@ -23,31 +23,35 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.*; -import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; /** * Solution based on thomaswue solution, commit: * commit d0a28599c293d3afe3291fc3cf169a7b25ae9ae6 - * Author: Thomas Wuerthinger + * Author: Thomas Wuerthinger * Date: Sun Jan 21 20:13:48 2024 +0100 * + * The goal here was to try to improve the runtime of his 10k + * solution of: 00:04.516 + * + * With Thomas latest changes, his time is probably much better + * already, and maybe even 1st place for the 10k too. + * See: https://github.com/gunnarmorling/1brc/pull/606 + * + * But as I was already coding something, I'll submit just to + * see if it will be faster than his *previous* 10k time of + * 00:04.516 + * * Changes: - * 1) Use LinkedBlockingQueue to store partial results, that - * will then be merged into the final map later. - * As different chunks finish at different times, this allows - * to process them as they finish, instead of joining the - * threads sequentially. - * This change seems more useful for the 10k dataset, as the - * runtime difference of each chunk is greater. - * 2) Use only 4 threads if the file is >= 14GB. - * This showed much better results on my local test, but I only - * run with 200 million rows (because of limited RAM), and I have - * no idea how it will perform on the 1brc HW. + * It's a similar idea of my previous solution, that if you split + * the chunks evenly, some threads might finish much faster and + * stay idle, so: + * 1) Create more chunks than threads, so the ones that finish first + * can do something; + * 2) Decrease chunk sizes as we get closer to the end of the file. */ public class CalculateAverage_tivrfoa { private static final String FILE = "./measurements.txt"; - private static LinkedBlockingQueue> partialResultQueue; - private static int C = 10_000; private static final int MIN_TEMP = -999; private static final int MAX_TEMP = 999; @@ -95,8 +99,16 @@ public class CalculateAverage_tivrfoa { } } + private static final int NUM_CPUS = Runtime.getRuntime().availableProcessors(); + private static final AtomicInteger chunkIdx = new AtomicInteger(); + private static long[] chunks; + private static int numChunks; + private static final class SolveChunk extends Thread { private long chunkStart, chunkEnd; + private Result[] results = new Result[10_000]; + private Result[] buckets = new Result[1 << 17]; + private int resIdx = 0; public SolveChunk(long chunkStart, long chunkEnd) { this.chunkStart = chunkStart; @@ -105,12 +117,132 @@ public class CalculateAverage_tivrfoa { @Override public void run() { - try { - partialResultQueue.put(parseLoop(chunkStart, chunkEnd)); + parseLoop(); + int chunk = chunkIdx.getAndIncrement(); + if (chunk < numChunks) { + chunkStart = chunks[chunk]; + chunkEnd = chunks[chunk + 1]; + run(); } - catch (Exception e) { - e.printStackTrace(); - System.exit(1); + } + + private void parseLoop() { + Scanner scanner = new Scanner(chunkStart, chunkEnd); + long word = scanner.getLong(); + long pos = findDelimiter(word); + while (scanner.hasNext()) { + long nameAddress = scanner.pos(); + long hash = 0; + + // Search for ';', one long at a time. + if (pos != 0) { + pos = Long.numberOfTrailingZeros(pos) >>> 3; + scanner.add(pos); + word = mask(word, pos); + hash = word; + + int number = scanNumber(scanner); + long nextWord = scanner.getLong(); + long nextPos = findDelimiter(nextWord); + + Result existingResult = buckets[hashToIndex(hash, buckets)]; + if (existingResult != null && existingResult.lastNameLong == word) { + word = nextWord; + pos = nextPos; + record(existingResult, number); + continue; + } + + scanner.setPos(nameAddress + pos); + } + else { + scanner.add(8); + hash = word; + long prevWord = word; + word = scanner.getLong(); + pos = findDelimiter(word); + if (pos != 0) { + pos = Long.numberOfTrailingZeros(pos) >>> 3; + scanner.add(pos); + word = mask(word, pos); + hash ^= word; + + Result existingResult = buckets[hashToIndex(hash, buckets)]; + if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { + int number = scanNumber(scanner); + word = scanner.getLong(); + pos = findDelimiter(word); + record(existingResult, number); + continue; + } + } + else { + scanner.add(8); + hash ^= word; + while (true) { + word = scanner.getLong(); + pos = findDelimiter(word); + if (pos != 0) { + pos = Long.numberOfTrailingZeros(pos) >>> 3; + scanner.add(pos); + word = mask(word, pos); + hash ^= word; + break; + } + else { + scanner.add(8); + hash ^= word; + } + } + } + } + + // Save length of name for later. + int nameLength = (int) (scanner.pos() - nameAddress); + int number = scanNumber(scanner); + + // Final calculation for index into hash table. + int tableIndex = hashToIndex(hash, buckets); + outer: while (true) { + Result existingResult = buckets[tableIndex]; + if (existingResult == null) { + existingResult = newEntry(buckets, nameAddress, tableIndex, nameLength, scanner); + results[resIdx++] = existingResult; + } + // Check for collision. + int i = 0; + int namePos = 0; + for (; i < nameLength + 1 - 8; i += 8) { + if (namePos >= existingResult.name.length || existingResult.name[namePos++] != scanner.getLongAt(nameAddress + i)) { + tableIndex = (tableIndex + 31) & (buckets.length - 1); + continue outer; + } + } + + int remainingShift = (64 - (nameLength + 1 - i) << 3); + if (((existingResult.lastNameLong ^ (scanner.getLongAt(nameAddress + i) << remainingShift)) == 0)) { + record(existingResult, number); + break; + } + else { + // Collision error, try next. + tableIndex = (tableIndex + 31) & (buckets.length - 1); + } + } + + word = scanner.getLong(); + pos = findDelimiter(word); + } + } + } + + private static void mergeIntoFinalMap(TreeMap map, Result[] newResults) { + for (var r : newResults) { + if (r == null) + return; + Result current = map.putIfAbsent(r.calcName(), r); + if (current != null) { + current.add(r); } } } @@ -127,20 +259,23 @@ public class CalculateAverage_tivrfoa { spawnWorker(); return; } - final int cpus = Runtime.getRuntime().availableProcessors(); - final long[] chunks = getSegments(cpus); - final int workers = chunks.length - 1; - partialResultQueue = new LinkedBlockingQueue<>(workers); - final SolveChunk[] threads = new SolveChunk[workers]; - for (int i = 0; i < workers; i++) { + + chunks = getSegments(NUM_CPUS); + numChunks = chunks.length - 1; + final SolveChunk[] threads = new SolveChunk[NUM_CPUS]; + chunkIdx.set(NUM_CPUS); + for (int i = 0; i < NUM_CPUS; i++) { threads[i] = new SolveChunk(chunks[i], chunks[i + 1]); threads[i].start(); } - final TreeMap ret = new TreeMap<>(); - for (int i = 0; i < workers; ++i) { - accumulateResults(ret, partialResultQueue.take()); + + TreeMap map = new TreeMap<>(); + for (int i = 0; i < NUM_CPUS; ++i) { + threads[i].join(); + mergeIntoFinalMap(map, threads[i].results); } - System.out.println(ret); + + System.out.println(map); System.out.close(); } @@ -159,129 +294,6 @@ public class CalculateAverage_tivrfoa { .transferTo(System.out); } - private static void accumulateResults(TreeMap result, List newResult) { - for (Result r : newResult) { - String name = r.calcName(); - Result current = result.putIfAbsent(name, r); - if (current != null) { - current.add(r); - } - } - } - - // Main parse loop. - private static ArrayList parseLoop(long chunkStart, long chunkEnd) { - ArrayList ret = new ArrayList<>(C); - Result[] results = new Result[1 << 17]; - Scanner scanner = new Scanner(chunkStart, chunkEnd); - long word = scanner.getLong(); - long pos = findDelimiter(word); - while (scanner.hasNext()) { - long nameAddress = scanner.pos(); - long hash = 0; - - // Search for ';', one long at a time. - if (pos != 0) { - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); - hash = word; - - int number = scanNumber(scanner); - long nextWord = scanner.getLong(); - long nextPos = findDelimiter(nextWord); - - Result existingResult = results[hashToIndex(hash, results)]; - if (existingResult != null && existingResult.lastNameLong == word) { - word = nextWord; - pos = nextPos; - record(existingResult, number); - continue; - } - - scanner.setPos(nameAddress + pos); - } - else { - scanner.add(8); - hash = word; - long prevWord = word; - word = scanner.getLong(); - pos = findDelimiter(word); - if (pos != 0) { - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); - hash ^= word; - - Result existingResult = results[hashToIndex(hash, results)]; - if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { - int number = scanNumber(scanner); - word = scanner.getLong(); - pos = findDelimiter(word); - record(existingResult, number); - continue; - } - } - else { - scanner.add(8); - hash ^= word; - while (true) { - word = scanner.getLong(); - pos = findDelimiter(word); - if (pos != 0) { - pos = Long.numberOfTrailingZeros(pos) >>> 3; - scanner.add(pos); - word = mask(word, pos); - hash ^= word; - break; - } - else { - scanner.add(8); - hash ^= word; - } - } - } - } - - // Save length of name for later. - int nameLength = (int) (scanner.pos() - nameAddress); - int number = scanNumber(scanner); - - // Final calculation for index into hash table. - int tableIndex = hashToIndex(hash, results); - outer: while (true) { - Result existingResult = results[tableIndex]; - if (existingResult == null) { - existingResult = newEntry(results, nameAddress, tableIndex, nameLength, scanner); - ret.add(existingResult); - } - // Check for collision. - int i = 0; - int namePos = 0; - for (; i < nameLength + 1 - 8; i += 8) { - if (namePos >= existingResult.name.length || existingResult.name[namePos++] != scanner.getLongAt(nameAddress + i)) { - tableIndex = (tableIndex + 31) & (results.length - 1); - continue outer; - } - } - - int remainingShift = (64 - (nameLength + 1 - i) << 3); - if (((existingResult.lastNameLong ^ (scanner.getLongAt(nameAddress + i) << remainingShift)) == 0)) { - record(existingResult, number); - break; - } - else { - // Collision error, try next. - tableIndex = (tableIndex + 31) & (results.length - 1); - } - } - - word = scanner.getLong(); - pos = findDelimiter(word); - } - return ret; - } - private static int scanNumber(Scanner scanPtr) { scanPtr.add(1); long numberWord = scanPtr.getLong(); @@ -356,28 +368,44 @@ public class CalculateAverage_tivrfoa { return r; } + /** + * - Split 70% of the file in even chunks for all cpus; + * - Create smaller chunks for the remainder of the file. + */ private static long[] getSegments(int cpus) throws IOException { try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { - long fileSize = fileChannel.size(); - int numberOfChunks = cpus / 2; - if (fileSize < (int) 14e9) { - C = 500; - numberOfChunks = cpus; - } - long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks; - long[] chunks = new long[numberOfChunks + 1]; - long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, java.lang.foreign.Arena.global()).address(); + final long fileSize = fileChannel.size(); + final long part1 = (long) (fileSize * 0.7); + final long part2 = (long) (fileSize * 0.2); + final long part3 = fileSize - part1 - part2; + final long bigChunkSize = (part1 - 1) / cpus; + final long smallChunkSize1 = (part2 - 1) / (cpus * 3); + final long smallChunkSize2 = (part3 - 1) / (cpus * 3); + final int numChunks = cpus + cpus * 3 + cpus * 3; + final long[] sizes = new long[numChunks]; + int l = 0, r = cpus; + Arrays.fill(sizes, l, r, bigChunkSize); + l = r; + r = l + cpus * 3; + Arrays.fill(sizes, l, r, smallChunkSize1); + l = r; + r = l + cpus * 3; + Arrays.fill(sizes, l, r, smallChunkSize2); + final long[] chunks = new long[sizes.length + 1]; + final long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, java.lang.foreign.Arena.global()).address(); chunks[0] = mappedAddress; - long endAddress = mappedAddress + fileSize; - Scanner s = new Scanner(mappedAddress, mappedAddress + fileSize); - for (int i = 1; i < numberOfChunks; ++i) { - long chunkAddress = mappedAddress + i * segmentSize; + final long endAddress = mappedAddress + fileSize; + final Scanner s = new Scanner(mappedAddress, mappedAddress + fileSize); + for (int i = 1, sizeIdx = 0; i < chunks.length - 1; ++i, sizeIdx = (sizeIdx + 1) % sizes.length) { + long chunkAddress = chunks[i - 1] + sizes[sizeIdx]; // Align to first row start. while (chunkAddress < endAddress && (s.getLongAt(chunkAddress++) & 0xFF) != '\n') ; chunks[i] = Math.min(chunkAddress, endAddress); + // System.err.printf("Chunk size %d\n", chunks[i] - chunks[i - 1]); } - chunks[numberOfChunks] = endAddress; + chunks[chunks.length - 1] = endAddress; + // System.err.printf("Chunk size %d\n", chunks[chunks.length - 1] - chunks[chunks.length - 2]); return chunks; } } @@ -428,4 +456,4 @@ public class CalculateAverage_tivrfoa { this.pos = l; } } -} \ No newline at end of file +}