Read file in multiple threads and String to Text (#427)

* - Read file in multiple threads if available: 17" -> 15" locally
- Changed String to BytesText with cache: 12" locally

* - Fixed bug
- BytesText to Text
- More checks when reading the file

* - Combining measurements should be thread safe
- More readability changes
This commit is contained in:
Anthony Goubard 2024-01-16 22:10:38 +01:00 committed by GitHub
parent 7bd2df7c59
commit e4b717e1a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 200 additions and 62 deletions

View File

@ -15,5 +15,5 @@
# limitations under the License. # limitations under the License.
# #
JAVA_OPTS="-Xmx2G" JAVA_OPTS=""
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_japplis $* java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_japplis $*

View File

@ -41,7 +41,9 @@ import java.util.concurrent.*;
* - Replaced compute lambda call with synchronized(city.intern()): 43" (due to intern()) * - Replaced compute lambda call with synchronized(city.intern()): 43" (due to intern())
* - Removed BufferedInputStream and replaced Measurement with IntSummaryStatistics (thanks davecom): still 23" but cleaner code * - Removed BufferedInputStream and replaced Measurement with IntSummaryStatistics (thanks davecom): still 23" but cleaner code
* - Execute same code on 1BRC server: 41" * - Execute same code on 1BRC server: 41"
* - One HashMap per thread: 17" locally * - One HashMap per thread: 17" locally (12" on 1BRC server)
* - Read file in multiple threads if available and
* - Changed String to (byte[]) Text with cache: 18" locally (but 8" -> 5" on laptop)
* *
* @author Anthony Goubard - Japplis * @author Anthony Goubard - Japplis
*/ */
@ -53,64 +55,113 @@ public class CalculateAverage_japplis {
private int precision = -1; private int precision = -1;
private int precisionLimitTenth; private int precisionLimitTenth;
private long fileSize;
private Map<String, IntSummaryStatistics> cityMeasurementMap = new ConcurrentHashMap<>(); private Map<Text, IntSummaryStatistics> cityMeasurementMap = new ConcurrentHashMap<>(10_000);
private List<Byte> previousBlockLastLine = new ArrayList<>(); private List<Byte> previousBlockLastLine = new ArrayList<>();
private Semaphore readFileLock = new Semaphore(MAX_COMPUTE_THREADS); private Semaphore readFileLock = new Semaphore(MAX_COMPUTE_THREADS);
private Queue<ByteArray> bufferPool = new ConcurrentLinkedQueue<>();
private void parseTemperatures(File measurementsFile) throws Exception { private void parseTemperatures(File measurementsFile) throws Exception {
try (InputStream measurementsFileIS = new FileInputStream(measurementsFile)) { fileSize = measurementsFile.length();
int readCount = BUFFER_SIZE; int blockIndex = 0;
int totalBlocks = (int) (fileSize / BUFFER_SIZE) + 1;
ExecutorService threadPool = Executors.newFixedThreadPool(MAX_COMPUTE_THREADS); ExecutorService threadPool = Executors.newFixedThreadPool(MAX_COMPUTE_THREADS);
List<Future> parseBlockTasks = new ArrayList<>(); List<Future> parseBlockTasks = new ArrayList<>();
while (readCount > 0) {
byte[] buffer = new byte[BUFFER_SIZE];
readCount = measurementsFileIS.read(buffer);
if (readCount > 0) {
readFileLock.acquire(); // Wait if all threads are busy
// Process the block in a thread while the main thread continues to read the file while (blockIndex < totalBlocks) {
Future parseBlockTask = threadPool.submit(parseTemperaturesBlock(buffer, readCount)); int availableReadThreads = Math.min(readFileLock.availablePermits(), totalBlocks - blockIndex);
if (availableReadThreads == 0) {
readFileLock.acquire(); // No need to loop in the 'while' if all threads are busy
readFileLock.release();
}
List<Future<ByteArray>> readBlockTasks = new ArrayList<>();
for (int i = 0; i < availableReadThreads; i++) {
readFileLock.acquire(); // Wait if all threads are busy
Callable<ByteArray> blockReader = readBlock(measurementsFile, blockIndex);
Future<ByteArray> readBlockTask = threadPool.submit(blockReader);
readBlockTasks.add(readBlockTask);
blockIndex++;
}
for (Future<ByteArray> readBlockTask : readBlockTasks) {
ByteArray buffer = readBlockTask.get();
if (buffer.array().length > 0) {
int startIndex = handleSplitLine(buffer.array());
readFileLock.acquire(); // Wait if all threads are busy
Runnable blockParser = parseTemperaturesBlock(buffer, startIndex);
Future parseBlockTask = threadPool.submit(blockParser);
parseBlockTasks.add(parseBlockTask); parseBlockTasks.add(parseBlockTask);
} }
} }
for (Future parseBlockTask : parseBlockTasks) // Wait for all tasks to finish }
for (Future parseBlockTask : parseBlockTasks) { // Wait for all tasks to finish
parseBlockTask.get(); parseBlockTask.get();
}
threadPool.shutdownNow(); threadPool.shutdownNow();
} }
private Callable<ByteArray> readBlock(File measurementsFile, long blockIndex) {
return () -> {
long fileIndex = blockIndex * BUFFER_SIZE;
if (fileIndex >= fileSize) {
readFileLock.release();
return new ByteArray(0);
}
try (InputStream measurementsFileIS = new FileInputStream(measurementsFile)) {
if (fileIndex > 0) {
long skipped = measurementsFileIS.skip(fileIndex);
while (skipped != fileIndex) {
skipped += measurementsFileIS.skip(fileIndex - skipped);
}
}
long bufferSize = Math.min(BUFFER_SIZE, fileSize - fileIndex);
ByteArray buffer = bufferSize == BUFFER_SIZE ? bufferPool.poll() : new ByteArray((int) bufferSize);
if (buffer == null) {
buffer = new ByteArray(BUFFER_SIZE);
}
int totalRead = measurementsFileIS.read(buffer.array(), 0, (int) bufferSize);
while (totalRead < bufferSize) {
byte[] extraBuffer = new byte[(int) (bufferSize - totalRead)];
int readCount = measurementsFileIS.read(extraBuffer);
System.arraycopy(extraBuffer, 0, buffer.array(), totalRead, readCount);
totalRead += readCount;
}
readFileLock.release();
return buffer;
}
};
} }
private Runnable parseTemperaturesBlock(byte[] buffer, int readCount) { private Runnable parseTemperaturesBlock(ByteArray buffer, int startIndex) {
int startIndex = handleSplitLine(buffer, readCount);
Runnable countAverageRun = () -> { Runnable countAverageRun = () -> {
int bufferIndex = startIndex; int bufferIndex = startIndex;
Map<String, IntSummaryStatistics> blockCityMeasurementMap = new HashMap<>(); Map<Text, IntSummaryStatistics> blockCityMeasurementMap = new HashMap<>(10_000);
Map<Integer, Text> textPool = new HashMap<>(10_000);
byte[] bufferArray = buffer.array();
try { try {
while (bufferIndex < readCount) { while (bufferIndex < bufferArray.length) {
bufferIndex = readNextLine(bufferIndex, buffer, blockCityMeasurementMap); bufferIndex = readNextLine(bufferIndex, bufferArray, blockCityMeasurementMap, textPool);
} }
} }
catch (ArrayIndexOutOfBoundsException ex) { catch (ArrayIndexOutOfBoundsException ex) {
// Done reading and parsing the buffer // Done reading and parsing the buffer
} }
if (bufferArray.length == BUFFER_SIZE)
bufferPool.add(buffer);
mergeBlockResults(blockCityMeasurementMap); mergeBlockResults(blockCityMeasurementMap);
readFileLock.release(); readFileLock.release();
}; };
return countAverageRun; return countAverageRun;
} }
private int handleSplitLine(byte[] buffer, int readCount) { private int handleSplitLine(byte[] buffer) {
int bufferIndex = readFirstLines(buffer); int bufferIndex = readFirstLines(buffer);
List<Byte> lastLine = new ArrayList<>(); // Store the last (partial) line of the block List<Byte> lastLine = new ArrayList<>(100); // Store the last (partial) line of the block
int tailIndex = readCount; int tailIndex = buffer.length;
if (tailIndex == buffer.length) {
byte car = buffer[--tailIndex]; byte car = buffer[--tailIndex];
while (car != '\n') { while (car != '\n') {
lastLine.add(0, car); lastLine.add(0, car);
car = buffer[--tailIndex]; car = buffer[--tailIndex];
} }
}
if (previousBlockLastLine.isEmpty()) { if (previousBlockLastLine.isEmpty()) {
previousBlockLastLine = lastLine; previousBlockLastLine = lastLine;
return bufferIndex; return bufferIndex;
@ -132,7 +183,7 @@ public class CalculateAverage_japplis {
for (int i = 0; i < splitLineBytes.length; i++) { for (int i = 0; i < splitLineBytes.length; i++) {
splitLineBytes[i] = previousBlockLastLine.get(i); splitLineBytes[i] = previousBlockLastLine.get(i);
} }
readNextLine(0, splitLineBytes, cityMeasurementMap); readNextLine(0, splitLineBytes, cityMeasurementMap, new HashMap<>());
return bufferIndex; return bufferIndex;
} }
@ -148,8 +199,9 @@ public class CalculateAverage_japplis {
int dotPos = bufferIndex; int dotPos = bufferIndex;
byte car = buffer[bufferIndex++]; byte car = buffer[bufferIndex++];
while (car != '\n') { while (car != '\n') {
if (car == '.') if (car == '.') {
dotPos = bufferIndex; dotPos = bufferIndex;
}
car = buffer[bufferIndex++]; car = buffer[bufferIndex++];
} }
precision = bufferIndex - dotPos - 1; precision = bufferIndex - dotPos - 1;
@ -158,40 +210,47 @@ public class CalculateAverage_japplis {
return startIndex; return startIndex;
} }
private int readNextLine(int bufferIndex, byte[] buffer, Map<String, IntSummaryStatistics> blockCityMeasurementMap) { private int readNextLine(int bufferIndex, byte[] buffer, Map<Text, IntSummaryStatistics> blockCityMeasurementMap, Map<Integer, Text> textPool) {
int startLineIndex = bufferIndex; int startLineIndex = bufferIndex;
while (buffer[bufferIndex] != ';') while (buffer[bufferIndex] != (byte) ';') {
bufferIndex++; bufferIndex++;
String city = new String(buffer, startLineIndex, bufferIndex - startLineIndex, StandardCharsets.UTF_8); }
// String city = new String(buffer, startLineIndex, bufferIndex - startLineIndex, StandardCharsets.UTF_8);
Text city = Text.getByteText(buffer, startLineIndex, bufferIndex - startLineIndex, textPool);
bufferIndex++; // skip ';' bufferIndex++; // skip ';'
int temperature = readTemperature(buffer, bufferIndex); int temperature = readTemperature(buffer, bufferIndex);
bufferIndex += precision + 3; // digit, dot and CR bufferIndex += precision + 3; // digit, dot and CR
if (temperature < 0) if (temperature < 0) {
bufferIndex++; bufferIndex++;
if (temperature <= -precisionLimitTenth || temperature >= precisionLimitTenth) }
if (temperature <= -precisionLimitTenth || temperature >= precisionLimitTenth) {
bufferIndex++; bufferIndex++;
}
addTemperature(city, temperature, blockCityMeasurementMap); addTemperature(city, temperature, blockCityMeasurementMap);
return bufferIndex; return bufferIndex;
} }
private int readTemperature(byte[] text, int measurementIndex) { private int readTemperature(byte[] buffer, int bufferIndex) {
boolean negative = text[measurementIndex] == '-'; boolean negative = buffer[bufferIndex] == (byte) '-';
if (negative) if (negative) {
measurementIndex++; bufferIndex++;
byte digitChar = text[measurementIndex++];
int temperature = 0;
while (digitChar != '\n') {
temperature = temperature * 10 + (digitChar - '0');
digitChar = text[measurementIndex++];
if (digitChar == '.')
digitChar = text[measurementIndex++];
} }
if (negative) byte digit = buffer[bufferIndex++];
int temperature = 0;
while (digit != (byte) '\n') {
temperature = temperature * 10 + (digit - (byte) '0');
digit = buffer[bufferIndex++];
if (digit == (byte) '.') { // Skip '.'
digit = buffer[bufferIndex++];
}
}
if (negative) {
temperature = -temperature; temperature = -temperature;
}
return temperature; return temperature;
} }
private void addTemperature(String city, int temperature, Map<String, IntSummaryStatistics> blockCityMeasurementMap) { private void addTemperature(Text city, int temperature, Map<Text, IntSummaryStatistics> blockCityMeasurementMap) {
IntSummaryStatistics measurement = blockCityMeasurementMap.get(city); IntSummaryStatistics measurement = blockCityMeasurementMap.get(city);
if (measurement == null) { if (measurement == null) {
measurement = new IntSummaryStatistics(); measurement = new IntSummaryStatistics();
@ -200,16 +259,20 @@ public class CalculateAverage_japplis {
measurement.accept(temperature); measurement.accept(temperature);
} }
private void mergeBlockResults(Map<String, IntSummaryStatistics> blockCityMeasurementMap) { private void mergeBlockResults(Map<Text, IntSummaryStatistics> blockCityMeasurementMap) {
blockCityMeasurementMap.forEach((city, measurement) -> { blockCityMeasurementMap.forEach((city, measurement) -> {
IntSummaryStatistics oldMeasurement = cityMeasurementMap.putIfAbsent(city, measurement); cityMeasurementMap.compute(city, (town, currentMeasurement) -> {
if (oldMeasurement != null) if (currentMeasurement == null) {
oldMeasurement.combine(measurement); return measurement;
}
currentMeasurement.combine(measurement);
return currentMeasurement;
});
}); });
} }
private void printTemperatureStatsByCity() { private void printTemperatureStatsByCity() {
Set<String> sortedCities = new TreeSet<>(cityMeasurementMap.keySet()); Set<Text> sortedCities = new TreeSet<>(cityMeasurementMap.keySet());
StringBuilder result = new StringBuilder(cityMeasurementMap.size() * 40); StringBuilder result = new StringBuilder(cityMeasurementMap.size() * 40);
result.append('{'); result.append('{');
sortedCities.forEach(city -> { sortedCities.forEach(city -> {
@ -217,7 +280,9 @@ public class CalculateAverage_japplis {
result.append(city); result.append(city);
result.append(getTemperatureStats(measurement)); result.append(getTemperatureStats(measurement));
}); });
if (!sortedCities.isEmpty()) {
result.delete(result.length() - 2, result.length()); result.delete(result.length() - 2, result.length());
}
result.append('}'); result.append('}');
String temperaturesByCity = result.toString(); String temperaturesByCity = result.toString();
System.out.println(temperaturesByCity); System.out.println(temperaturesByCity);
@ -242,9 +307,10 @@ public class CalculateAverage_japplis {
for (int i = temperatureAsText.length(); i < minCharacters; i++) { for (int i = temperatureAsText.length(); i < minCharacters; i++) {
temperatureAsText = temperature < 0 ? "-0" + temperatureAsText.substring(1) : "0" + temperatureAsText; temperatureAsText = temperature < 0 ? "-0" + temperatureAsText.substring(1) : "0" + temperatureAsText;
} }
resultBuilder.append(temperatureAsText.substring(0, temperatureAsText.length() - precision)); int dotPosition = temperatureAsText.length() - precision;
resultBuilder.append(temperatureAsText.substring(0, dotPosition));
resultBuilder.append('.'); resultBuilder.append('.');
resultBuilder.append(temperatureAsText.substring(temperatureAsText.length() - precision)); resultBuilder.append(temperatureAsText.substring(dotPosition));
} }
public static final void main(String... args) throws Exception { public static final void main(String... args) throws Exception {
@ -253,4 +319,76 @@ public class CalculateAverage_japplis {
cityTemperaturesCalculator.parseTemperatures(new File(measurementFile)); cityTemperaturesCalculator.parseTemperatures(new File(measurementFile));
cityTemperaturesCalculator.printTemperatureStatsByCity(); cityTemperaturesCalculator.printTemperatureStatsByCity();
} }
private class ByteArray {
private byte[] array;
private ByteArray(int size) {
array = new byte[size];
}
private byte[] array() {
return array;
}
}
private static class Text implements Comparable<Text> {
private final byte[] textBytes;
private final int hash;
private String text;
private Text(byte[] buffer, int startIndex, int length, int hash) {
textBytes = new byte[length];
this.hash = hash;
System.arraycopy(buffer, startIndex, textBytes, 0, length);
}
private static Text getByteText(byte[] buffer, int startIndex, int length, Map<Integer, Text> textPool) {
int hash = hashCode(buffer, startIndex, length);
Text textFromPool = textPool.get(hash);
if (textFromPool == null || !Arrays.equals(buffer, startIndex, startIndex + length, textFromPool.textBytes, 0, length)) {
Text newText = new Text(buffer, startIndex, length, hash);
textPool.put(hash, newText);
return newText;
}
return textFromPool;
}
private static int hashCode(byte[] buffer, int startIndex, int length) {
int hash = 31;
int endIndex = startIndex + length;
for (int i = startIndex; i < endIndex; i++) {
hash = 31 * hash + buffer[i];
}
return hash;
}
@Override
public int hashCode() {
return hash;
}
@Override
public boolean equals(Object other) {
return other != null &&
hashCode() == other.hashCode() &&
other instanceof Text &&
Arrays.equals(textBytes, ((Text) other).textBytes);
}
@Override
public int compareTo(Text other) {
return toString().compareTo(other.toString());
}
@Override
public String toString() {
if (text == null) {
text = new String(textBytes, StandardCharsets.UTF_8);
}
return text;
}
}
} }