tonivade improved solution (#582)

* tonivade: improved solution that avoids HashMap

* use java 21.0.2

* same hash same station

* remove unused parameter in sameStation

* use length too

* refactor parallelization

* use parallel GC

* refactor

* refactor
This commit is contained in:
Antonio Muñoz 2024-01-25 23:07:20 +01:00 committed by GitHub
parent 0bd1675571
commit 65d2c1b0c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 153 additions and 135 deletions

View File

@ -15,5 +15,5 @@
# limitations under the License. # limitations under the License.
# #
JAVA_OPTS="-Xmx1G -Xms1G -XX:+AlwaysPreTouch --enable-preview" JAVA_OPTS="-Xmx1G -Xms1G -XX:+AlwaysPreTouch -XX:+UseParallelGC -XX:-UseCompressedOops --enable-preview"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_tonivade java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_tonivade

View File

@ -17,4 +17,4 @@
# Uncomment below to use sdk # Uncomment below to use sdk
source "$HOME/.sdkman/bin/sdkman-init.sh" source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.1-tem 1>&2 sdk use java 21.0.2-tem 1>&2

View File

@ -15,9 +15,6 @@
*/ */
package dev.morling.onebrc; package dev.morling.onebrc;
import static java.util.Comparator.comparing;
import static java.util.stream.Collectors.joining;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.FileChannel; import java.nio.channels.FileChannel;
@ -26,9 +23,8 @@ import java.nio.charset.StandardCharsets;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.nio.file.StandardOpenOption; import java.nio.file.StandardOpenOption;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.StructuredTaskScope; import java.util.concurrent.StructuredTaskScope;
import java.util.concurrent.StructuredTaskScope.Subtask; import java.util.concurrent.StructuredTaskScope.Subtask;
@ -37,32 +33,16 @@ public class CalculateAverage_tonivade {
private static final String FILE = "./measurements.txt"; private static final String FILE = "./measurements.txt";
private static final int EOL = 10; private static final int MIN_CHUNK_SIZE = 1024;
private static final int MINUS = 45; private static final int MAX_NAME_LENGTH = 128;
private static final int SEMICOLON = 59; private static final int MAX_TEMP_LENGTH = 8;
public static void main(String[] args) throws IOException, InterruptedException, ExecutionException { public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
var result = readFile(); System.out.println(readFile());
var measurements = getMeasurements(result);
System.out.println(measurements);
} }
static record PartialResult(int end, Map<Name, Station> map) { private static Map<String, Station> readFile() throws IOException, InterruptedException, ExecutionException {
Map<String, Station> result = new TreeMap<>();
void merge(Map<Name, Station> result) {
map.forEach((name, station) -> result.merge(name, station, Station::merge));
}
}
private static String getMeasurements(Map<Name, Station> result) {
return result.values().stream().sorted(comparing(Station::getName))
.map(Station::asString).collect(joining(", ", "{", "}"));
}
private static Map<Name, Station> readFile() throws IOException, InterruptedException, ExecutionException {
Map<Name, Station> result = HashMap.newHashMap(10_000);
try (var channel = FileChannel.open(Paths.get(FILE), StandardOpenOption.READ)) { try (var channel = FileChannel.open(Paths.get(FILE), StandardOpenOption.READ)) {
long consumed = 0; long consumed = 0;
long remaining = channel.size(); long remaining = channel.size();
@ -70,8 +50,11 @@ public class CalculateAverage_tonivade {
var buffer = channel.map( var buffer = channel.map(
MapMode.READ_ONLY, consumed, Math.min(remaining, Integer.MAX_VALUE)); MapMode.READ_ONLY, consumed, Math.min(remaining, Integer.MAX_VALUE));
if (buffer.remaining() <= 1024) { int chunks = Runtime.getRuntime().availableProcessors();
var partialResult = readChunk(buffer, 0, buffer.remaining()); int chunkSize = buffer.remaining() / chunks;
int leftover = buffer.remaining() % chunks;
if (chunkSize < MIN_CHUNK_SIZE) {
var partialResult = new Chunk(buffer, 0, buffer.remaining()).read();
consumed += partialResult.end(); consumed += partialResult.end();
remaining -= partialResult.end(); remaining -= partialResult.end();
@ -79,17 +62,12 @@ public class CalculateAverage_tonivade {
partialResult.merge(result); partialResult.merge(result);
} }
else { else {
var chunks = Runtime.getRuntime().availableProcessors();
var chunksSize = buffer.remaining() / chunks;
var leftover = buffer.remaining() % chunks;
try (var scope = new StructuredTaskScope.ShutdownOnFailure()) { try (var scope = new StructuredTaskScope.ShutdownOnFailure()) {
var tasks = new ArrayList<Subtask<PartialResult>>(chunks); var tasks = new ArrayList<Subtask<PartialResult>>(chunks);
for (int i = 0; i < chunks; i++) { for (int i = 0; i < chunks; i++) {
int start = i * chunksSize; int start = i * chunkSize;
int length = chunksSize + (i < chunks ? leftover : 0); int length = chunkSize + (i < chunks ? leftover : 0);
tasks.add(scope.fork(() -> readChunk( tasks.add(scope.fork(new Chunk(buffer, start, length)::read));
buffer, findStart(buffer, start), start + length)));
} }
scope.join(); scope.join();
scope.throwIfFailed(); scope.throwIfFailed();
@ -106,29 +84,26 @@ public class CalculateAverage_tonivade {
return result; return result;
} }
private static PartialResult readChunk(ByteBuffer buffer, int start, int end) { static final class Chunk {
final byte[] name = new byte[128];
final byte[] temp = new byte[8];
final Map<Name, Station> map = HashMap.newHashMap(1000);
int position = start;
while (position < end) {
int semicolon = readName(buffer, position, end - position, name);
if (semicolon < 0) {
break;
}
int endOfLine = readTemp(buffer, semicolon + 1, end - semicolon - 1, temp); private static final int EOL = 10;
if (endOfLine < 0) { private static final int MINUS = 45;
break; private static final int SEMICOLON = 59;
}
map.computeIfAbsent(new Name(name, semicolon - position), Station::new) final ByteBuffer buffer;
.add(parseTemp(temp, endOfLine - semicolon - 1)); final int start;
final int end;
// skip end of line final byte[] name = new byte[MAX_NAME_LENGTH];
position = endOfLine + 1; final byte[] temp = new byte[MAX_TEMP_LENGTH];
} final Stations stations = new Stations();
return new PartialResult(position, map);
int hash;
Chunk(ByteBuffer buffer, int start, int length) {
this.buffer = buffer;
this.start = findStart(buffer, start);
this.end = start + length;
} }
private static int findStart(ByteBuffer buffer, int start) { private static int findStart(ByteBuffer buffer, int start) {
@ -143,21 +118,48 @@ public class CalculateAverage_tonivade {
return start; return start;
} }
private static int readName(ByteBuffer buffer, int offset, int length, byte[] name) { PartialResult read() {
return readUntil(buffer, offset, length, name, SEMICOLON); int position = start;
while (position < end) {
int semicolon = readName(position, end - position);
if (semicolon < 0) {
break;
} }
private static int readTemp(ByteBuffer buffer, int offset, int length, byte[] percentage) { int endOfLine = readTemp(semicolon + 1, end - semicolon - 1);
return readUntil(buffer, offset, length, percentage, EOL); if (endOfLine < 0) {
break;
} }
private static int readUntil(ByteBuffer buffer, int offset, int length, byte[] array, int target) { stations.find(name, semicolon - position, hash)
.add(parseTemp(temp, endOfLine - semicolon - 1));
// skip end of line
position = endOfLine + 1;
}
return new PartialResult(position, stations.buckets);
}
private int readName(int offset, int length) {
hash = 1;
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
byte b = buffer.get(i + offset); byte b = buffer.get(i + offset);
if (b == target) { if (b == SEMICOLON) {
return i + offset; return i + offset;
} }
array[i] = b; name[i] = b;
hash = 31 * hash + b;
}
return -1;
}
private int readTemp(int offset, int length) {
for (int i = 0; i < length; i++) {
byte b = buffer.get(i + offset);
if (b == EOL) {
return i + offset;
}
temp[i] = b;
} }
return -1; return -1;
} }
@ -188,50 +190,48 @@ public class CalculateAverage_tonivade {
private static int toInt(byte c) { private static int toInt(byte c) {
return c - 48; return c - 48;
} }
static final class Name {
private final byte[] value;
Name(byte[] source, int length) {
value = new byte[length];
System.arraycopy(source, 0, value, 0, length);
} }
@Override static final class Stations {
public int hashCode() {
return Arrays.hashCode(value);
}
@Override private static final int NUMBER_OF_BUCKETS = 1000;
public boolean equals(Object obj) { private static final int BUCKET_SIZE = 50;
if (obj instanceof Name other) {
return Arrays.equals(value, other.value);
}
return false;
}
@Override final Station[][] buckets = new Station[NUMBER_OF_BUCKETS][BUCKET_SIZE];
public String toString() {
return new String(value, StandardCharsets.UTF_8); Station find(byte[] name, int length, int hash) {
var bucket = buckets[Math.abs(hash % NUMBER_OF_BUCKETS)];
for (int i = 0; i < BUCKET_SIZE; i++) {
if (bucket[i] == null) {
bucket[i] = new Station(name, length, hash);
return bucket[i];
}
else if (bucket[i].sameName(length, hash)) {
return bucket[i];
}
}
throw new IllegalStateException("no more space left");
} }
} }
static final class Station { static final class Station {
private final Name name; private final byte[] name;
private final int hash;
private int min = Integer.MAX_VALUE; private int min = 1000;
private int max = Integer.MIN_VALUE; private int max = -1000;
private int sum; private int sum;
private long count; private long count;
Station(Name name) { Station(byte[] source, int length, int hash) {
this.name = name; name = new byte[length];
System.arraycopy(source, 0, name, 0, length);
this.hash = hash;
} }
String getName() { String getName() {
return name.toString(); return new String(name, StandardCharsets.UTF_8);
} }
void add(int value) { void add(int value) {
@ -249,8 +249,13 @@ public class CalculateAverage_tonivade {
return this; return this;
} }
String asString() { @Override
return name + "=" + toDouble(min) + "/" + round(mean()) + "/" + toDouble(max); public String toString() {
return toDouble(min) + "/" + round(mean()) + "/" + toDouble(max);
}
boolean sameName(int length, int hash) {
return name.length == length && this.hash == hash;
} }
private double mean() { private double mean() {
@ -265,4 +270,17 @@ public class CalculateAverage_tonivade {
return Math.round(value * 10.) / 10.; return Math.round(value * 10.) / 10.;
} }
} }
/**
 * Result of reading one chunk of the mapped buffer.
 *
 * @param end      buffer position reported by the reader once it stopped parsing
 * @param stations bucketed station table filled while reading; unused slots are {@code null}
 */
static record PartialResult(int end, Station[][] stations) {

    /**
     * Folds every station collected in this partial result into {@code result}.
     * Each station is keyed by {@link Station#getName()}; when the key is already
     * present, the two entries are combined through {@code Station::merge}.
     *
     * @param result accumulator map that receives (or is merged with) each station
     */
    void merge(Map<String, Station> result) {
        for (int b = 0; b < stations.length; b++) {
            Station[] slots = stations[b];
            for (int s = 0; s < slots.length; s++) {
                Station current = slots[s];
                if (current == null) {
                    // empty slot: nothing was stored here
                    continue;
                }
                result.merge(current.getName(), current, Station::merge);
            }
        }
    }
}
} }