337 lines
12 KiB
Java
337 lines
12 KiB
Java
|
/*
|
||
|
* Copyright 2023 The original authors
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
package dev.morling.onebrc;
|
||
|
|
||
|
import java.io.*;
|
||
|
import java.nio.MappedByteBuffer;
|
||
|
import java.nio.channels.FileChannel;
|
||
|
import java.nio.charset.*;
|
||
|
import java.util.*;
|
||
|
import java.util.concurrent.*;
|
||
|
|
||
|
public class CalculateAverage_couragelee {
|
||
|
private static class Temperature {
|
||
|
private int cnt = 0;
|
||
|
|
||
|
private double sum = 0;
|
||
|
|
||
|
private double min;
|
||
|
|
||
|
private double max;
|
||
|
|
||
|
public Temperature(String tempStr) {
|
||
|
double temp = Double.parseDouble(tempStr);
|
||
|
this.min = temp;
|
||
|
this.max = temp;
|
||
|
this.sum = temp;
|
||
|
this.cnt++;
|
||
|
}
|
||
|
|
||
|
public Temperature(int cnt, double sum, double min, double max) {
|
||
|
this.cnt = cnt;
|
||
|
this.sum = sum;
|
||
|
this.min = min;
|
||
|
this.max = max;
|
||
|
}
|
||
|
|
||
|
public Temperature addRecord(String tempStr) {
|
||
|
double temp = Double.parseDouble(tempStr);
|
||
|
Temperature newTemp = new Temperature(this.cnt, this.sum, this.min, this.max);
|
||
|
newTemp.min = Math.min(temp, newTemp.min);
|
||
|
newTemp.max = Math.max(temp, newTemp.max);
|
||
|
newTemp.sum += temp;
|
||
|
newTemp.cnt++;
|
||
|
return newTemp;
|
||
|
}
|
||
|
|
||
|
public Temperature merge(Temperature newValue) {
|
||
|
Temperature oldTemp = new Temperature(this.cnt, this.sum, this.min, this.max);
|
||
|
oldTemp.min = Math.min(newValue.min, oldTemp.min);
|
||
|
oldTemp.max = Math.max(newValue.max, oldTemp.max);
|
||
|
oldTemp.sum += newValue.sum;
|
||
|
oldTemp.cnt += newValue.cnt;
|
||
|
return oldTemp;
|
||
|
}
|
||
|
|
||
|
public void update(String tempStr) {
|
||
|
double temp = parseDouble(tempStr);
|
||
|
this.min = Math.min(temp, this.min);
|
||
|
this.max = Math.max(temp, this.max);
|
||
|
this.sum += temp;
|
||
|
this.cnt++;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public String toString() {
|
||
|
return STR."\{min}/\{Math.round((sum / cnt) * 10.0) / 10.0}/\{max}";
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private static final String FILE_PATH = "./measurements.txt";
|
||
|
|
||
|
// 并行任务的数量
|
||
|
public static final int CONCURRENT_NUM = 20;
|
||
|
|
||
|
private static FileChannel fc;
|
||
|
private static long fcSize;
|
||
|
|
||
|
private static int segmentSize;
|
||
|
|
||
|
private static Map<String, Temperature> temperatureMap;
|
||
|
|
||
|
// 需要拼接的行信息
|
||
|
private static Map<String, byte[]> tempBytesMap = new ConcurrentHashMap<>();
|
||
|
|
||
|
// 缓存double解析数据
|
||
|
private static Map<String, Double> doubleCache;
|
||
|
|
||
|
public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
|
||
|
// 初始化
|
||
|
File file = new File(FILE_PATH);
|
||
|
fc = new RandomAccessFile(file, "r").getChannel();
|
||
|
fcSize = fc.size();
|
||
|
segmentSize = (int) Math.ceil((double) fcSize / CONCURRENT_NUM);
|
||
|
|
||
|
calculate();
|
||
|
|
||
|
String resStr = temperatureMap.toString();
|
||
|
System.out.println(resStr);
|
||
|
}
|
||
|
|
||
|
private static void calculate() throws IOException, InterruptedException, ExecutionException {
|
||
|
ThreadPoolExecutor executor = new ThreadPoolExecutor(CONCURRENT_NUM, CONCURRENT_NUM, 0, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>());
|
||
|
|
||
|
temperatureMap = new ConcurrentSkipListMap<>();
|
||
|
preHeatDoubleCache();
|
||
|
|
||
|
List<Future<Map<String, Temperature>>> res = new ArrayList<>();
|
||
|
long startPos = 0;
|
||
|
if (fcSize < 1000000) {
|
||
|
Future<Map<String, Temperature>> partRes = executor.submit(new Task(startPos, fcSize));
|
||
|
Map<String, Temperature> map = partRes.get();
|
||
|
temperatureMap.putAll(map);
|
||
|
}
|
||
|
else {
|
||
|
while (true) {
|
||
|
if (startPos + segmentSize >= fcSize) {
|
||
|
Future<Map<String, Temperature>> partRes = executor.submit(new Task(startPos, fcSize - startPos));
|
||
|
res.add(partRes);
|
||
|
break;
|
||
|
}
|
||
|
else {
|
||
|
Future<Map<String, Temperature>> partRes = executor.submit(new Task(startPos, segmentSize));
|
||
|
res.add(partRes);
|
||
|
startPos += segmentSize;
|
||
|
}
|
||
|
}
|
||
|
// 合并结果
|
||
|
for (Future<Map<String, Temperature>> future : res) {
|
||
|
Map<String, Temperature> stringTemperatureMap = future.get();
|
||
|
for (Map.Entry<String, Temperature> entry : stringTemperatureMap.entrySet()) {
|
||
|
String station = entry.getKey();
|
||
|
Temperature value = entry.getValue();
|
||
|
temperatureMap.merge(station, value, (oldValue, newValue) -> oldValue.merge(newValue));
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
executor.shutdown();
|
||
|
executor.awaitTermination(10, TimeUnit.MINUTES);
|
||
|
|
||
|
// 处理拼接的行信息,不超过总并发数,顺序处理
|
||
|
for (Map.Entry<String, byte[]> entry : tempBytesMap.entrySet()) {
|
||
|
String key = entry.getKey();
|
||
|
if (key.startsWith("E")) {
|
||
|
continue;
|
||
|
}
|
||
|
byte[] part1 = entry.getValue();
|
||
|
byte[] part2 = tempBytesMap.getOrDefault("E" + key, new byte[0]);
|
||
|
byte[] bytes = new byte[part1.length + part2.length];
|
||
|
System.arraycopy(part1, 0, bytes, 0, part1.length);
|
||
|
System.arraycopy(part2, 0, bytes, part1.length, part2.length);
|
||
|
String[] lines = convertToString1(bytes, 0, bytes.length - 1);
|
||
|
for (String line : lines) {
|
||
|
try {
|
||
|
handleRecordConcurrently(line);
|
||
|
}
|
||
|
catch (Exception e) {
|
||
|
e.printStackTrace();
|
||
|
System.out.println(line);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private static class Task implements Callable<Map<String, Temperature>> {
|
||
|
private long startPos;
|
||
|
private long size;
|
||
|
|
||
|
public Task(long startPos, long size) throws IOException {
|
||
|
this.startPos = startPos;
|
||
|
this.size = size;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public Map<String, Temperature> call() throws Exception {
|
||
|
Map<String, Temperature> map = new HashMap<>(10000);
|
||
|
try {
|
||
|
// 1亿个byte
|
||
|
boolean firstRowHandled = false;
|
||
|
|
||
|
MappedByteBuffer buffer = fc.map(FileChannel.MapMode.READ_ONLY, startPos, size);
|
||
|
byte[] lastLastRowBytes = null;
|
||
|
while (buffer.hasRemaining()) {
|
||
|
byte[] bytes = new byte[10000];
|
||
|
// 先拼上上一次的最后一行
|
||
|
int startIndex = 0;
|
||
|
if (lastLastRowBytes != null) {
|
||
|
for (byte lastLastRowByte : lastLastRowBytes) {
|
||
|
bytes[startIndex++] = lastLastRowByte;
|
||
|
}
|
||
|
}
|
||
|
int readLength = Math.min(buffer.remaining(), 10000 - startIndex);
|
||
|
lastLastRowBytes = null;
|
||
|
buffer.get(bytes, startIndex, readLength);
|
||
|
// 处理第一行
|
||
|
int firstIndex = 0;
|
||
|
if (!firstRowHandled) {
|
||
|
firstRowHandled = true;
|
||
|
if (startPos == 0) {
|
||
|
// 全文第一行,不要特殊处理
|
||
|
}
|
||
|
else {
|
||
|
while (bytes[firstIndex] != 10) {
|
||
|
firstIndex++;
|
||
|
}
|
||
|
byte[] firstRowBytes = Arrays.copyOfRange(bytes, 0, firstIndex + 1);
|
||
|
tempBytesMap.put("E" + String.valueOf(startPos - 1), firstRowBytes);
|
||
|
firstIndex++;
|
||
|
}
|
||
|
}
|
||
|
// 分段的最后一行(可能不完整)
|
||
|
int lastIndex = startIndex + readLength - 1;
|
||
|
|
||
|
while (bytes[lastIndex] != 10) {
|
||
|
lastIndex--;
|
||
|
}
|
||
|
if (lastIndex == startIndex + readLength - 1) {
|
||
|
// 分段的最后一行是完整的
|
||
|
}
|
||
|
else {
|
||
|
// 暂存一下
|
||
|
lastLastRowBytes = Arrays.copyOfRange(bytes, lastIndex + 1, startIndex + readLength);
|
||
|
}
|
||
|
|
||
|
// [firstIndex, lastIndex] 这之间的数据是完整的多行数据
|
||
|
String[] lines = convertToString1(bytes, firstIndex, lastIndex);
|
||
|
handleRecord(map, lines);
|
||
|
}
|
||
|
// 处理最后一行
|
||
|
if (lastLastRowBytes != null) {
|
||
|
tempBytesMap.put(String.valueOf(startPos + size - 1), Arrays.copyOf(lastLastRowBytes, lastLastRowBytes.length));
|
||
|
}
|
||
|
else {
|
||
|
tempBytesMap.put(String.valueOf(startPos + size - 1), new byte[0]);
|
||
|
}
|
||
|
}
|
||
|
catch (Exception e) {
|
||
|
e.printStackTrace();
|
||
|
}
|
||
|
return map;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private static void handleRecord(Map<String, Temperature> map, String[] records) {
|
||
|
if (records == null || records.length == 0) {
|
||
|
return;
|
||
|
}
|
||
|
for (String record : records) {
|
||
|
if ("".equals(record)) {
|
||
|
continue;
|
||
|
}
|
||
|
int index = record.indexOf(";");
|
||
|
String station = record.substring(0, index);
|
||
|
String stationValue = record.substring(index + 1);
|
||
|
Temperature temperature = map.get(station);
|
||
|
if (temperature == null) {
|
||
|
temperature = new Temperature(stationValue);
|
||
|
map.put(station, temperature);
|
||
|
}
|
||
|
else {
|
||
|
temperature.update(stationValue);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private static void handleRecordConcurrently(String record) {
|
||
|
if (record.isEmpty()) {
|
||
|
return;
|
||
|
}
|
||
|
String[] split = record.split(";");
|
||
|
String station = split[0];
|
||
|
String stationValue = split[1];
|
||
|
// temperatureMap中只能新增值,不会删除
|
||
|
if (temperatureMap.get(station) == null) {
|
||
|
if (temperatureMap.putIfAbsent(station, new Temperature(stationValue)) != null) {
|
||
|
// 插入失败
|
||
|
temperatureMap.computeIfPresent(station, (key, oldValue) -> oldValue.addRecord(stationValue));
|
||
|
}
|
||
|
}
|
||
|
else {
|
||
|
// 已经有值了
|
||
|
temperatureMap.computeIfPresent(station, (key, oldValue) -> oldValue.addRecord(stationValue));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/**
|
||
|
*
|
||
|
* @param bytes
|
||
|
* @param start 起始索引,包含
|
||
|
* @param end 结束索引,包含
|
||
|
* @return
|
||
|
*/
|
||
|
private static String[] convertToString1(byte[] bytes, int start, int end) {
|
||
|
if (bytes == null || bytes.length == 0) {
|
||
|
return new String[0];
|
||
|
}
|
||
|
String s = new String(bytes, start, (end - start + 1), StandardCharsets.UTF_8);
|
||
|
String[] split = s.split("\n");
|
||
|
return split;
|
||
|
}
|
||
|
|
||
|
// 预热-99.9到99.9之间的数,且始终包含一位小数
|
||
|
private static void preHeatDoubleCache() {
|
||
|
doubleCache = new ConcurrentHashMap<>();
|
||
|
for (int i = -99; i < 99; i++) {
|
||
|
for (int j = 0; j < 10; j++) {
|
||
|
String stand = String.valueOf(i);
|
||
|
String v = stand + "." + j;
|
||
|
doubleCache.put(v, Double.parseDouble(v));
|
||
|
}
|
||
|
}
|
||
|
for (int i = 0; i < 10; i++) {
|
||
|
String stand = "-0";
|
||
|
String v = stand + "." + i;
|
||
|
doubleCache.put(v, Double.parseDouble(v));
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
private static double parseDouble(String tempStr) {
|
||
|
return doubleCache.get(tempStr);
|
||
|
}
|
||
|
}
|