Sixth attempt CalculateAverage_zerninv.java (#407)

* rethink chunking

* fix typo
This commit is contained in:
zerninv 2024-01-15 19:25:52 +00:00 committed by GitHub
parent dd9a3dde7e
commit d18b10708b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,14 +25,15 @@ import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.StandardOpenOption; import java.nio.file.StandardOpenOption;
import java.util.*; import java.util.ArrayList;
import java.util.concurrent.ExecutionException; import java.util.List;
import java.util.concurrent.Executors; import java.util.TreeMap;
import java.util.concurrent.Future;
public class CalculateAverage_zerninv { public class CalculateAverage_zerninv {
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
private static final int MIN_FILE_SIZE = 1024 * 1024 * 16; private static final int L3_CACHE_SIZE = 128 * 1024 * 1024;
// Processor count reported by the runtime, read once when this class is initialized.
private static final int CORES = Runtime.getRuntime().availableProcessors();
// Chunk size in bytes: what is left of L3_CACHE_SIZE after subtracting
// SIZE * ENTRY_SIZE (taken from MeasurementContainer) once per core, split evenly
// across the cores, minus a further 1024 * CORES.
// NOTE(review): the expression is plain int arithmetic and is zero or negative once
// CORES is large enough; confirm callers cope with a non-positive chunk size.
private static final int CHUNK_SIZE = (L3_CACHE_SIZE - MeasurementContainer.SIZE * MeasurementContainer.ENTRY_SIZE * CORES) / CORES - 1024 * CORES;
// #.## // #.##
private static final int THREE_DIGITS_MASK = 0x2e0000; private static final int THREE_DIGITS_MASK = 0x2e0000;
@ -48,48 +49,49 @@ public class CalculateAverage_zerninv {
private static final Unsafe UNSAFE = initUnsafe(); private static final Unsafe UNSAFE = initUnsafe();
public static void main(String[] args) throws IOException { public static void main(String[] args) throws IOException, InterruptedException {
var results = new HashMap<String, MeasurementAggregation>();
try (var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { try (var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
var fileSize = channel.size(); var fileSize = channel.size();
var memorySegment = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()); var minChunkSize = Math.min(fileSize, CHUNK_SIZE);
long address = memorySegment.address();
var cores = Runtime.getRuntime().availableProcessors();
var minChunkSize = fileSize < MIN_FILE_SIZE ? fileSize : fileSize / cores;
var chunks = splitByChunks(address, address + fileSize, minChunkSize);
var executor = Executors.newFixedThreadPool(cores); var tasks = new TaskThread[CORES];
List<Future<Map<String, MeasurementAggregation>>> fResults = new ArrayList<>(); for (int i = 0; i < tasks.length; i++) {
for (int i = 1; i < chunks.size(); i++) { tasks[i] = new TaskThread(new MeasurementContainer(), (int) (fileSize / minChunkSize / CORES + 1));
final long prev = chunks.get(i - 1);
final long curr = chunks.get(i);
fResults.add(executor.submit(() -> calcForChunk(prev, curr)));
} }
fResults.forEach(f -> { var memorySegment = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global());
try { var address = memorySegment.address();
f.get().forEach((key, value) -> { var chunks = splitByChunks(address, address + fileSize, minChunkSize);
var result = results.get(key); for (int i = 0; i < chunks.size() - 1; i++) {
if (result != null) { var task = tasks[i % CORES];
result.merge(value); task.addChunk(chunks.get(i), chunks.get(i + 1));
}
for (var task : tasks) {
task.start();
}
var results = new TreeMap<String, TemperatureAggregation>();
for (var task : tasks) {
task.join();
task.measurements()
.forEach(measurement -> {
var aggr = results.get(measurement.station());
if (aggr == null) {
results.put(measurement.station(), measurement.aggregation());
} }
else { else {
results.put(key, value); aggr.merge(measurement.aggregation());
} }
}); });
} }
catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
}
});
executor.shutdown();
}
var bos = new BufferedOutputStream(System.out); var bos = new BufferedOutputStream(System.out);
bos.write(new TreeMap<>(results).toString().getBytes(StandardCharsets.UTF_8)); bos.write(new TreeMap<>(results).toString().getBytes(StandardCharsets.UTF_8));
bos.write('\n'); bos.write('\n');
bos.flush(); bos.flush();
} }
}
private static Unsafe initUnsafe() { private static Unsafe initUnsafe() {
try { try {
@ -103,7 +105,7 @@ public class CalculateAverage_zerninv {
} }
private static List<Long> splitByChunks(long address, long end, long minChunkSize) { private static List<Long> splitByChunks(long address, long end, long minChunkSize) {
List<Long> result = new ArrayList<>(); List<Long> result = new ArrayList<>((int) ((end - address) / minChunkSize + 1));
result.add(address); result.add(address);
while (address < end) { while (address < end) {
address += Math.min(end - address, minChunkSize); address += Math.min(end - address, minChunkSize);
@ -114,60 +116,20 @@ public class CalculateAverage_zerninv {
return result; return result;
} }
private static Map<String, MeasurementAggregation> calcForChunk(long offset, long end) { private static final class TemperatureAggregation {
var results = new MeasurementContainer();
long cityOffset;
int hashCode, temperature, word;
byte cityNameSize, b;
while (offset < end) {
cityOffset = offset;
hashCode = 0;
while ((b = UNSAFE.getByte(offset++)) != DELIMITER) {
hashCode = hashCode * 31 + b;
}
cityNameSize = (byte) (offset - cityOffset - 1);
word = UNSAFE.getInt(offset);
offset += 4;
if ((word & TWO_NEGATIVE_DIGITS_MASK) == TWO_NEGATIVE_DIGITS_MASK) {
word >>>= 8;
temperature = ZERO * 11 - ((word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK));
}
else if ((word & THREE_DIGITS_MASK) == THREE_DIGITS_MASK) {
temperature = (word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK) - ZERO * 111;
}
else if ((word & TWO_DIGITS_MASK) == TWO_DIGITS_MASK) {
temperature = (word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK) - ZERO * 11;
offset--;
}
else {
// #.##-
word = (word >>> 8) | (UNSAFE.getByte(offset++) << 24);
temperature = ZERO * 111 - ((word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK));
}
offset++;
results.put(cityOffset, cityNameSize, hashCode, (short) temperature);
}
return results.toStringMap();
}
private static final class MeasurementAggregation {
private long sum; private long sum;
private int count; private int count;
private short min; private short min;
private short max; private short max;
public MeasurementAggregation(long sum, int count, short min, short max) { public TemperatureAggregation(long sum, int count, short min, short max) {
this.sum = sum; this.sum = sum;
this.count = count; this.count = count;
this.min = min; this.min = min;
this.max = max; this.max = max;
} }
public void merge(MeasurementAggregation o) { public void merge(TemperatureAggregation o) {
if (o == null) { if (o == null) {
return; return;
} }
@ -183,6 +145,9 @@ public class CalculateAverage_zerninv {
} }
} }
/**
 * Pairs a station name with the aggregation collected for that station.
 *
 * @param station     name of the station
 * @param aggregation aggregated values belonging to {@code station}
 */
private record Measurement(String station, TemperatureAggregation aggregation) {}
private static final class MeasurementContainer { private static final class MeasurementContainer {
private static final int SIZE = 1024 * 16; private static final int SIZE = 1024 * 16;
@ -235,26 +200,26 @@ public class CalculateAverage_zerninv {
} }
} }
public Map<String, MeasurementAggregation> toStringMap() { public List<Measurement> measurements() {
var result = new HashMap<String, MeasurementAggregation>(); var result = new ArrayList<Measurement>(1000);
int count; int count;
for (int i = 0; i < SIZE; i++) { for (int i = 0; i < SIZE; i++) {
long ptr = this.address + i * ENTRY_SIZE; long ptr = this.address + i * ENTRY_SIZE;
count = UNSAFE.getInt(ptr + COUNT_OFFSET); count = UNSAFE.getInt(ptr + COUNT_OFFSET);
if (count != 0) { if (count != 0) {
var measurements = new MeasurementAggregation( var measurements = new TemperatureAggregation(
UNSAFE.getLong(ptr + SUM_OFFSET), UNSAFE.getLong(ptr + SUM_OFFSET),
count, count,
UNSAFE.getShort(ptr + MIN_OFFSET), UNSAFE.getShort(ptr + MIN_OFFSET),
UNSAFE.getShort(ptr + MAX_OFFSET)); UNSAFE.getShort(ptr + MAX_OFFSET));
var key = createString(UNSAFE.getLong(ptr + ADDRESS_OFFSET), UNSAFE.getByte(ptr + SIZE_OFFSET)); var key = createString(UNSAFE.getLong(ptr + ADDRESS_OFFSET), UNSAFE.getByte(ptr + SIZE_OFFSET));
result.put(key, measurements); result.add(new Measurement(key, measurements));
} }
} }
return result; return result;
} }
private boolean isEqual(long address, long address2, byte size) { private static boolean isEqual(long address, long address2, byte size) {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
if (UNSAFE.getByte(address + i) != UNSAFE.getByte(address2 + i)) { if (UNSAFE.getByte(address + i) != UNSAFE.getByte(address2 + i)) {
return false; return false;
@ -271,4 +236,69 @@ public class CalculateAverage_zerninv {
return new String(arr); return new String(arr);
} }
} }
/**
 * Thread that parses the address ranges handed to it through
 * {@link #addChunk(long, long)} and feeds every parsed record into its own
 * {@code MeasurementContainer}.
 */
private static class TaskThread extends Thread {
    private final MeasurementContainer container;
    // Parallel lists: begins.get(k) and ends.get(k) delimit the k-th assigned range.
    private final List<Long> begins;
    private final List<Long> ends;

    /**
     * @param container destination for every record parsed by this thread
     * @param chunks    initial capacity reserved for the two range lists
     */
    private TaskThread(MeasurementContainer container, int chunks) {
        this.begins = new ArrayList<>(chunks);
        this.ends = new ArrayList<>(chunks);
        this.container = container;
    }

    /**
     * Registers one more address range to be parsed by {@link #run()}.
     *
     * @param begin address of the first byte of the range
     * @param end   address at which parsing of the range stops (exclusive)
     */
    public void addChunk(long begin, long end) {
        this.begins.add(begin);
        this.ends.add(end);
    }

    /** Parses the registered ranges one after another, in registration order. */
    @Override
    public void run() {
        int next = 0;
        while (next < begins.size()) {
            final long from = begins.get(next);
            final long to = ends.get(next);
            calcForChunk(from, to);
            next++;
        }
    }

    /**
     * Returns whatever the underlying container reports as its measurements.
     *
     * @return the list produced by {@code container.measurements()}
     */
    public List<Measurement> measurements() {
        return this.container.measurements();
    }

    /**
     * Walks the raw memory between {@code offset} and {@code end} record by record
     * and stores each record in the container.
     *
     * <p>A record is a name terminated by {@code DELIMITER}, followed by a number
     * that is recognised by testing the next four bytes against the digit masks.
     * NOTE(review): the decoding presumably relies on ASCII digits with {@code ZERO}
     * being the code of '0' and on little-endian {@code getInt} - confirm.
     *
     * @param offset address of the first byte to parse
     * @param end    address at which parsing stops (exclusive)
     */
    private void calcForChunk(long offset, long end) {
        long pos = offset;
        while (pos < end) {
            // Scan the name up to the delimiter, folding its bytes into a 31-based hash.
            final long nameStart = pos;
            int nameHash = 0;
            for (byte ch = UNSAFE.getByte(pos++); ch != DELIMITER; ch = UNSAFE.getByte(pos++)) {
                nameHash = nameHash * 31 + ch;
            }
            // pos is already one past the delimiter, hence the extra -1.
            final byte nameLength = (byte) (pos - nameStart - 1);

            // Grab four bytes of the number at once and step over them.
            final int raw = UNSAFE.getInt(pos);
            pos += 4;

            final int value;
            if ((raw & TWO_NEGATIVE_DIGITS_MASK) == TWO_NEGATIVE_DIGITS_MASK) {
                // Drop the lowest byte, then combine the bytes left at positions 0 and 2.
                final int digits = raw >>> 8;
                value = ZERO * 11 - ((digits & BYTE_MASK) * 10 + ((digits >>> 16) & BYTE_MASK));
            }
            else if ((raw & THREE_DIGITS_MASK) == THREE_DIGITS_MASK) {
                // Digits sit in bytes 0, 1 and 3; byte 2 is skipped.
                value = (raw & BYTE_MASK) * 100 + ((raw >>> 8) & BYTE_MASK) * 10 + ((raw >>> 24) & BYTE_MASK) - ZERO * 111;
            }
            else if ((raw & TWO_DIGITS_MASK) == TWO_DIGITS_MASK) {
                // Digits sit in bytes 0 and 2; only three of the four bytes belong
                // to the number, so give one byte back.
                value = (raw & BYTE_MASK) * 10 + ((raw >>> 16) & BYTE_MASK) - ZERO * 11;
                pos--;
            }
            else {
                // Remaining layout: discard the lowest byte and pull one extra byte
                // into the top so the word can be decoded like the three-digit case,
                // with the result negated.
                final int digits = (raw >>> 8) | (UNSAFE.getByte(pos++) << 24);
                value = ZERO * 111 - ((digits & BYTE_MASK) * 100 + ((digits >>> 8) & BYTE_MASK) * 10 + ((digits >>> 24) & BYTE_MASK));
            }
            // Step over the byte that follows the number.
            pos++;
            container.put(nameStart, nameLength, nameHash, (short) value);
        }
    }
}
} }