improve hard disk access locality, another 8% (#591)

* improve hard disk access locality, another 8%

* add some comments & credit

* fixed format
This commit is contained in:
Van Phu DO 2024-01-27 22:54:43 +09:00 committed by GitHub
parent 5092eb44d1
commit c228633b57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -28,18 +28,21 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import sun.misc.Unsafe; import sun.misc.Unsafe;
public class CalculateAverage_abeobk { public class CalculateAverage_abeobk {
private static final boolean SHOW_ANALYSIS = false; private static final boolean SHOW_ANALYSIS = false;
private static final int CPU_CNT = Runtime.getRuntime().availableProcessors();
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
private static final int BUCKET_SIZE = 1 << 16; private static final int BUCKET_SIZE = 1 << 16;
private static final int BUCKET_MASK = BUCKET_SIZE - 1; private static final int BUCKET_MASK = BUCKET_SIZE - 1;
private static final int MAX_STR_LEN = 100; private static final int MAX_STR_LEN = 100;
private static final int MAX_STATIONS = 10000; private static final int MAX_STATIONS = 10000;
private static final long CHUNK_SZ = 1 << 22; // 4MB chunk
private static final Unsafe UNSAFE = initUnsafe(); private static final Unsafe UNSAFE = initUnsafe();
private static final long[] HASH_MASKS = new long[]{ private static final long[] HASH_MASKS = new long[]{
0x0L, 0x0L,
@ -52,6 +55,11 @@ public class CalculateAverage_abeobk {
0xffffffffffffffL, 0xffffffffffffffL,
0xffffffffffffffffL, }; 0xffffffffffffffffL, };
private static AtomicInteger chunk_id = new AtomicInteger(0);
private static int chunk_cnt;
private static long start_addr, end_addr;
private static Stat[][] all_res;
private static final void debug(String s, Object... args) { private static final void debug(String s, Object... args) {
System.out.println(String.format(s, args)); System.out.println(String.format(s, args));
} }
@ -153,20 +161,6 @@ public class CalculateAverage_abeobk {
} }
} }
// split into chunks
static long[] slice(long start_addr, long end_addr, long chunk_size, int cpu_cnt) {
long[] ptrs = new long[cpu_cnt + 1];
ptrs[0] = start_addr;
for (int i = 1; i < cpu_cnt; i++) {
long addr = start_addr + i * chunk_size;
while (addr < end_addr && UNSAFE.getByte(addr++) != '\n')
;
ptrs[i] = Math.min(addr, end_addr);
}
ptrs[cpu_cnt] = end_addr;
return ptrs;
}
// idea from royvanrijn // idea from royvanrijn
static final long getSemiPosCode(final long word) { static final long getSemiPosCode(final long word) {
long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;; long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;;
@ -189,15 +183,38 @@ public class CalculateAverage_abeobk {
return (short) ((abs_val ^ signed) - signed); return (short) ((abs_val ^ signed) - signed);
} }
// Thread pool worker
static final class Worker extends Thread {
final int thread_id;
Worker(int i) {
thread_id = i;
this.start();
}
@Override
public void run() {
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
int cnt = 0;
int id;
int cls = 0;
// process in small chunk to maintain disk locality (artsiomkorzun trick)
// but keep going instead of merging
while ((id = chunk_id.getAndIncrement()) < chunk_cnt) {
long addr = start_addr + id * CHUNK_SZ;
long end = Math.min(addr + CHUNK_SZ, end_addr);
// adjust start
if (id > 0) {
while (UNSAFE.getByte(addr++) != '\n')
;
}
// parse loop
// optimize for contest // optimize for contest
// save as much slow memory access as possible // save as much slow memory access as possible
// about 50% key < 8chars, 25% key between 8-10 chars // about 50% key < 8chars, 25% key between 8-10 chars
// keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2... // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
static final Node[] parse(int thread_id, long start, long end) {
int cls = 0;
long addr = start;
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
// parse loop
while (addr < end) { while (addr < end) {
long row_addr = addr; long row_addr = addr;
@ -220,6 +237,7 @@ public class CalculateAverage_abeobk {
var node = map[bucket]; var node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, tail, val); map[bucket] = new Node(row_addr, tail, val);
cnt++;
break; break;
} }
if (node.tail == tail) { if (node.tail == tail) {
@ -252,6 +270,7 @@ public class CalculateAverage_abeobk {
var node = map[bucket]; var node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, word0, tail, val); map[bucket] = new Node(row_addr, word0, tail, val);
cnt++;
break; break;
} }
if (node.word0 == word0 && node.tail == tail) { if (node.word0 == word0 && node.tail == tail) {
@ -266,7 +285,6 @@ public class CalculateAverage_abeobk {
} }
// why not going for more? tested, slower // why not going for more? tested, slower
long hash = word0; long hash = word0;
while (semipos_code == 0) { while (semipos_code == 0) {
hash ^= word; hash ^= word;
@ -290,6 +308,7 @@ public class CalculateAverage_abeobk {
var node = map[bucket]; var node = map[bucket];
if (node == null) { if (node == null) {
map[bucket] = new Node(row_addr, word0, tail, val); map[bucket] = new Node(row_addr, word0, tail, val);
cnt++;
break; break;
} }
if (node.contentEquals(row_addr, word0, tail, keylen)) { if (node.contentEquals(row_addr, word0, tail, keylen)) {
@ -301,11 +320,21 @@ public class CalculateAverage_abeobk {
cls++; cls++;
} }
} }
}
if (SHOW_ANALYSIS) { if (SHOW_ANALYSIS) {
debug("Thread %d collision = %d", thread_id, cls); debug("Thread %d collision = %d", thread_id, cls);
} }
return map;
Stat[] stats = new Stat[cnt];
int i = 0;
for (var node : map) {
if (node != null) {
stats[i++] = new Stat(node);
}
}
all_res[thread_id] = stats;
}
} }
// thomaswue trick // thomaswue trick
@ -329,44 +358,32 @@ public class CalculateAverage_abeobk {
return; return;
} }
try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
long file_size = file.size(); long file_size = file.size();
long end_addr = start_addr + file_size; start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
end_addr = start_addr + file_size;
// only use all cpus on large file // only use all cpus on large file
int cpu_cnt = file_size < 1e6 ? 1 : Runtime.getRuntime().availableProcessors(); int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT;
long chunk_size = Math.ceilDiv(file_size, cpu_cnt); chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ);
all_res = new Stat[cpu_cnt][];
// processing List<Worker> workers = IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList();
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt); for (var w : workers)
w.join();
List<List<Stat>> maps = IntStream.range(0, cpu_cnt)
.mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1]))
.map(map -> {
List<Stat> stats = new ArrayList<>();
for (var node : map) {
if (node == null)
continue;
stats.add(new Stat(node));
}
return stats;
})
.parallel()
.toList();
// collect all results
TreeMap<String, Stat> ms = new TreeMap<>(); TreeMap<String, Stat> ms = new TreeMap<>();
for (var stats : maps) { for (var res : all_res) {
for (var s : stats) { for (var s : res) {
var stat = ms.putIfAbsent(s.key, s); var stat = ms.putIfAbsent(s.key, s);
if (stat != null) if (stat != null)
stat.node.merge(s.node); stat.node.merge(s.node);
} }
} }
// print result // print output
System.out.println(ms); System.out.println(ms);
System.out.close(); System.out.close();
} }
} }
}