Second submission to keep a bit of dignity (#581)
* Dmitry challenge * Dmitry submit 2. Use a MemorySegment mapped from the FileChannel and Unsafe to read bytes from disk. 4-second speedup in local test, from 20s to 16s.
This commit is contained in:
parent
65d2c1b0c9
commit
b20e7365e7
@ -17,4 +17,5 @@
|
|||||||
|
|
||||||
|
|
||||||
#JAVA_OPTS="-verbose:gc"
|
#JAVA_OPTS="-verbose:gc"
|
||||||
|
JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation"
|
||||||
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_bufistov $1 $2
|
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_bufistov $1 $2
|
||||||
|
@ -15,11 +15,17 @@
|
|||||||
*/
|
*/
|
||||||
package dev.morling.onebrc;
|
package dev.morling.onebrc;
|
||||||
|
|
||||||
|
import sun.misc.Unsafe;
|
||||||
|
|
||||||
import static java.lang.Math.toIntExact;
|
import static java.lang.Math.toIntExact;
|
||||||
|
|
||||||
|
import java.lang.foreign.Arena;
|
||||||
|
import java.lang.reflect.Field;
|
||||||
import java.nio.MappedByteBuffer;
|
import java.nio.MappedByteBuffer;
|
||||||
import java.nio.channels.FileChannel;
|
import java.nio.channels.FileChannel;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
import java.time.Instant;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
@ -32,7 +38,27 @@ import java.io.FileInputStream;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.concurrent.Future;
|
import java.util.concurrent.Future;
|
||||||
|
|
||||||
class ResultRow {
|
class ByteArrayWrapper {
|
||||||
|
private final byte[] data;
|
||||||
|
|
||||||
|
public ByteArrayWrapper(byte[] data) {
|
||||||
|
this.data = data;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object other) {
|
||||||
|
return Arrays.equals(data, ((ByteArrayWrapper) other).data);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Arrays.hashCode(data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class CalculateAverage_bufistov {
|
||||||
|
|
||||||
|
static class ResultRow {
|
||||||
byte[] station;
|
byte[] station;
|
||||||
|
|
||||||
String stationString;
|
String stationString;
|
||||||
@ -57,9 +83,11 @@ class ResultRow {
|
|||||||
this.suma = value;
|
this.suma = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition) {
|
void setStation(long startPosition, long endPosition) {
|
||||||
this.station = new byte[endPosition - startPosition];
|
this.station = new byte[(int) (endPosition - startPosition)];
|
||||||
byteBuffer.slice(startPosition, station.length).get(this.station, 0, station.length);
|
for (int i = 0; i < this.station.length; ++i) {
|
||||||
|
this.station[i] = UNSAFE.getByte(startPosition + i);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public String toString() {
|
public String toString() {
|
||||||
@ -71,7 +99,7 @@ class ResultRow {
|
|||||||
return Math.round(value * 10.0) / 10.0;
|
return Math.round(value * 10.0) / 10.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultRow update(long newValue) {
|
void update(long newValue) {
|
||||||
this.count += 1;
|
this.count += 1;
|
||||||
this.suma += newValue;
|
this.suma += newValue;
|
||||||
if (newValue < this.min) {
|
if (newValue < this.min) {
|
||||||
@ -80,7 +108,6 @@ class ResultRow {
|
|||||||
else if (newValue > this.max) {
|
else if (newValue > this.max) {
|
||||||
this.max = newValue;
|
this.max = newValue;
|
||||||
}
|
}
|
||||||
return this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultRow merge(ResultRow another) {
|
ResultRow merge(ResultRow another) {
|
||||||
@ -90,27 +117,9 @@ class ResultRow {
|
|||||||
this.max = Math.max(this.max, another.max);
|
this.max = Math.max(this.max, another.max);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
class ByteArrayWrapper {
|
|
||||||
private final byte[] data;
|
|
||||||
|
|
||||||
public ByteArrayWrapper(byte[] data) {
|
|
||||||
this.data = data;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
static class OpenHash {
|
||||||
public boolean equals(Object other) {
|
|
||||||
return Arrays.equals(data, ((ByteArrayWrapper) other).data);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
return Arrays.hashCode(data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class OpenHash {
|
|
||||||
ResultRow[] data;
|
ResultRow[] data;
|
||||||
int dataSizeMask;
|
int dataSizeMask;
|
||||||
|
|
||||||
@ -150,26 +159,26 @@ class OpenHash {
|
|||||||
merge(station, value, hashByteArray(station));
|
merge(station, value, hashByteArray(station));
|
||||||
}
|
}
|
||||||
|
|
||||||
void merge(MappedByteBuffer byteBuffer, final int startPosition, final int endPosition, int hashValue, final long value) {
|
void merge(final long startPosition, long endPosition, int hashValue, long value) {
|
||||||
while (data[hashValue] != null && !equalsToStation(byteBuffer, startPosition, endPosition, data[hashValue].station)) {
|
while (data[hashValue] != null && !equalsToStation(startPosition, endPosition, data[hashValue].station)) {
|
||||||
hashValue += 1;
|
hashValue += 1;
|
||||||
hashValue &= dataSizeMask;
|
hashValue &= dataSizeMask;
|
||||||
}
|
}
|
||||||
if (data[hashValue] == null) {
|
if (data[hashValue] == null) {
|
||||||
data[hashValue] = new ResultRow(value);
|
data[hashValue] = new ResultRow(value);
|
||||||
data[hashValue].setStation(byteBuffer, startPosition, endPosition);
|
data[hashValue].setStation(startPosition, endPosition);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
data[hashValue].update(value);
|
data[hashValue].update(value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean equalsToStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition, byte[] station) {
|
boolean equalsToStation(long startPosition, long endPosition, byte[] station) {
|
||||||
if (endPosition - startPosition != station.length) {
|
if (endPosition - startPosition != station.length) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < station.length; ++i, ++startPosition) {
|
for (int i = 0; i < station.length; ++i, ++startPosition) {
|
||||||
if (byteBuffer.get(startPosition) != station[i])
|
if (UNSAFE.getByte(startPosition) != station[i])
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
@ -185,25 +194,38 @@ class OpenHash {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public class CalculateAverage_bufistov {
|
static final Unsafe UNSAFE;
|
||||||
|
|
||||||
|
static {
|
||||||
|
try {
|
||||||
|
Field unsafe = Unsafe.class.getDeclaredField("theUnsafe");
|
||||||
|
unsafe.setAccessible(true);
|
||||||
|
UNSAFE = (Unsafe) unsafe.get(Unsafe.class);
|
||||||
|
}
|
||||||
|
catch (Throwable e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static final long LINE_SEPARATOR = '\n';
|
static final long LINE_SEPARATOR = '\n';
|
||||||
|
|
||||||
public static class FileRead implements Callable<HashMap<ByteArrayWrapper, ResultRow>> {
|
public static class FileRead implements Callable<HashMap<ByteArrayWrapper, ResultRow>> {
|
||||||
|
|
||||||
private final FileChannel fileChannel;
|
private final FileChannel fileChannel;
|
||||||
|
|
||||||
private long currentLocation;
|
private long currentLocation;
|
||||||
private int bytesToRead;
|
private long bytesToRead;
|
||||||
|
|
||||||
private final int hashCapacityPow2 = 18;
|
private static final int hashCapacityPow2 = 18;
|
||||||
private final int hashCapacityMask = (1 << hashCapacityPow2) - 1;
|
|
||||||
|
|
||||||
public FileRead(long startLocation, int bytesToRead, FileChannel fileChannel) {
|
static final int hashCapacityMask = (1 << hashCapacityPow2) - 1;
|
||||||
|
|
||||||
|
public FileRead(FileChannel fileChannel, long startLocation, long bytesToRead, boolean firstSegment) {
|
||||||
|
this.fileChannel = fileChannel;
|
||||||
this.currentLocation = startLocation;
|
this.currentLocation = startLocation;
|
||||||
this.bytesToRead = bytesToRead;
|
this.bytesToRead = bytesToRead;
|
||||||
this.fileChannel = fileChannel;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -211,21 +233,13 @@ public class CalculateAverage_bufistov {
|
|||||||
try {
|
try {
|
||||||
OpenHash openHash = new OpenHash(hashCapacityPow2);
|
OpenHash openHash = new OpenHash(hashCapacityPow2);
|
||||||
log("Reading the channel: " + currentLocation + ":" + bytesToRead);
|
log("Reading the channel: " + currentLocation + ":" + bytesToRead);
|
||||||
byte[] suffix = new byte[128];
|
|
||||||
if (currentLocation > 0) {
|
if (currentLocation > 0) {
|
||||||
toLineBegin(suffix);
|
toLineBeginPrefix();
|
||||||
}
|
|
||||||
while (bytesToRead > 0) {
|
|
||||||
int bufferSize = Math.min(1 << 24, bytesToRead);
|
|
||||||
MappedByteBuffer byteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, currentLocation, bufferSize);
|
|
||||||
bytesToRead -= bufferSize;
|
|
||||||
currentLocation += bufferSize;
|
|
||||||
int suffixBytes = 0;
|
|
||||||
if (currentLocation < fileChannel.size()) {
|
|
||||||
suffixBytes = toLineBegin(suffix);
|
|
||||||
}
|
|
||||||
processChunk(byteBuffer, bufferSize, suffix, suffixBytes, openHash);
|
|
||||||
}
|
}
|
||||||
|
toLineBeginSuffix();
|
||||||
|
var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, currentLocation, bytesToRead, Arena.global());
|
||||||
|
currentLocation = memorySegment.address();
|
||||||
|
processChunk(openHash);
|
||||||
log("Done Reading the channel: " + currentLocation + ":" + bytesToRead);
|
log("Done Reading the channel: " + currentLocation + ":" + bytesToRead);
|
||||||
return openHash.toJavaHashMap();
|
return openHash.toJavaHashMap();
|
||||||
}
|
}
|
||||||
@ -240,39 +254,40 @@ public class CalculateAverage_bufistov {
|
|||||||
return byteBuffer.get();
|
return byteBuffer.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
int toLineBegin(byte[] suffix) throws IOException {
|
void toLineBeginPrefix() throws IOException {
|
||||||
int bytesConsumed = 0;
|
while (getByte(currentLocation - 1) != LINE_SEPARATOR) {
|
||||||
if (getByte(currentLocation - 1) != LINE_SEPARATOR) {
|
|
||||||
while (getByte(currentLocation) != LINE_SEPARATOR) { // Small bug here if last chunk is less than a line and has no '\n' at the end. Valid input should have '\n' for all rows.
|
|
||||||
suffix[bytesConsumed++] = getByte(currentLocation);
|
|
||||||
++currentLocation;
|
++currentLocation;
|
||||||
--bytesToRead;
|
--bytesToRead;
|
||||||
}
|
}
|
||||||
++currentLocation;
|
|
||||||
--bytesToRead;
|
|
||||||
}
|
|
||||||
return bytesConsumed;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void processChunk(MappedByteBuffer byteBuffer, int bufferSize, byte[] suffix, int suffixBytes, OpenHash result) {
|
void toLineBeginSuffix() throws IOException {
|
||||||
int nameBegin = 0;
|
while (getByte(currentLocation + bytesToRead - 1) != LINE_SEPARATOR) {
|
||||||
int nameEnd = -1;
|
++bytesToRead;
|
||||||
int numberBegin = -1;
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void processChunk(OpenHash result) {
|
||||||
|
long nameBegin = currentLocation;
|
||||||
|
long nameEnd = -1;
|
||||||
|
long numberBegin = -1;
|
||||||
int currentHash = 0;
|
int currentHash = 0;
|
||||||
int currentMask = 0;
|
int currentMask = 0;
|
||||||
int nameHash = 0;
|
int nameHash = 0;
|
||||||
for (int currentPosition = 0; currentPosition < bufferSize; ++currentPosition) {
|
long end = currentLocation + bytesToRead;
|
||||||
byte nextByte = byteBuffer.get(currentPosition);
|
byte nextByte;
|
||||||
|
for (; currentLocation < end; ++currentLocation) {
|
||||||
|
nextByte = UNSAFE.getByte(currentLocation);
|
||||||
if (nextByte == ';') {
|
if (nextByte == ';') {
|
||||||
nameEnd = currentPosition;
|
nameEnd = currentLocation;
|
||||||
numberBegin = currentPosition + 1;
|
numberBegin = currentLocation + 1;
|
||||||
nameHash = currentHash & hashCapacityMask;
|
nameHash = currentHash & hashCapacityMask;
|
||||||
}
|
}
|
||||||
else if (nextByte == LINE_SEPARATOR) {
|
else if (nextByte == LINE_SEPARATOR) {
|
||||||
long value = getValue(byteBuffer, numberBegin, currentPosition);
|
long value = getValue(numberBegin, currentLocation);
|
||||||
// log("Station name: '" + getStationName(byteBuffer, nameBegin, nameEnd) + "' value: " + value + " hash: " + nameHash);
|
// log("Station name: '" + getStationName(nameBegin, nameEnd) + "' value: " + value + " hash: " + nameHash);
|
||||||
result.merge(byteBuffer, nameBegin, nameEnd, nameHash, value);
|
result.merge(nameBegin, nameEnd, nameHash, value);
|
||||||
nameBegin = currentPosition + 1;
|
nameBegin = currentLocation + 1;
|
||||||
currentHash = 0;
|
currentHash = 0;
|
||||||
currentMask = 0;
|
currentMask = 0;
|
||||||
}
|
}
|
||||||
@ -281,38 +296,14 @@ public class CalculateAverage_bufistov {
|
|||||||
currentMask = (currentMask + 1) & 3;
|
currentMask = (currentMask + 1) & 3;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (nameBegin < bufferSize) {
|
|
||||||
byte[] lastLine = new byte[bufferSize - nameBegin + suffixBytes];
|
|
||||||
byte[] prefix = new byte[bufferSize - nameBegin];
|
|
||||||
byteBuffer.slice(nameBegin, prefix.length).get(prefix, 0, prefix.length);
|
|
||||||
System.arraycopy(prefix, 0, lastLine, 0, prefix.length);
|
|
||||||
System.arraycopy(suffix, 0, lastLine, prefix.length, suffixBytes);
|
|
||||||
processLastLine(lastLine, result);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void processLastLine(byte[] lastLine, OpenHash result) {
|
long getValue(long startLocation, long endLocation) {
|
||||||
int numberBegin = -1;
|
byte nextByte = UNSAFE.getByte(startLocation);
|
||||||
byte[] stationName = null;
|
|
||||||
for (int i = 0; i < lastLine.length; ++i) {
|
|
||||||
if (lastLine[i] == ';') {
|
|
||||||
stationName = new byte[i];
|
|
||||||
System.arraycopy(lastLine, 0, stationName, 0, stationName.length);
|
|
||||||
numberBegin = i + 1;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
long value = getValue(lastLine, numberBegin);
|
|
||||||
// log("Station name: '" + new String(stationName, StandardCharsets.UTF_8) + "' value: " + value);
|
|
||||||
result.merge(stationName, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
long getValue(MappedByteBuffer byteBuffer, int startLocation, int endLocation) {
|
|
||||||
byte nextByte = byteBuffer.get(startLocation);
|
|
||||||
boolean negate = nextByte == '-';
|
boolean negate = nextByte == '-';
|
||||||
long result = negate ? 0 : nextByte - '0';
|
long result = negate ? 0 : nextByte - '0';
|
||||||
for (int i = startLocation + 1; i < endLocation; ++i) {
|
for (long i = startLocation + 1; i < endLocation; ++i) {
|
||||||
nextByte = byteBuffer.get(i);
|
nextByte = UNSAFE.getByte(i);
|
||||||
if (nextByte != '.') {
|
if (nextByte != '.') {
|
||||||
result *= 10;
|
result *= 10;
|
||||||
result += nextByte - '0';
|
result += nextByte - '0';
|
||||||
@ -321,23 +312,11 @@ public class CalculateAverage_bufistov {
|
|||||||
return negate ? -result : result;
|
return negate ? -result : result;
|
||||||
}
|
}
|
||||||
|
|
||||||
long getValue(byte[] lastLine, int startLocation) {
|
String getStationName(long from, long to) {
|
||||||
byte nextByte = lastLine[startLocation];
|
byte[] bytes = new byte[(int) (to - from)];
|
||||||
boolean negate = nextByte == '-';
|
for (int i = 0; i < bytes.length; ++i) {
|
||||||
long result = negate ? 0 : nextByte - '0';
|
bytes[i] = UNSAFE.getByte(from + i);
|
||||||
for (int i = startLocation + 1; i < lastLine.length; ++i) {
|
|
||||||
nextByte = lastLine[i];
|
|
||||||
if (nextByte != '.') {
|
|
||||||
result *= 10;
|
|
||||||
result += nextByte - '0';
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return negate ? -result : result;
|
|
||||||
}
|
|
||||||
|
|
||||||
String getStationName(MappedByteBuffer byteBuffer, int from, int to) {
|
|
||||||
byte[] bytes = new byte[to - from];
|
|
||||||
byteBuffer.slice(from, to - from).get(0, bytes);
|
|
||||||
return new String(bytes, StandardCharsets.UTF_8);
|
return new String(bytes, StandardCharsets.UTF_8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -349,7 +328,7 @@ public class CalculateAverage_bufistov {
|
|||||||
}
|
}
|
||||||
log("InputFile: " + fileName);
|
log("InputFile: " + fileName);
|
||||||
FileInputStream fileInputStream = new FileInputStream(fileName);
|
FileInputStream fileInputStream = new FileInputStream(fileName);
|
||||||
int numThreads = 32;
|
int numThreads = 2 * Runtime.getRuntime().availableProcessors();
|
||||||
if (args.length > 1) {
|
if (args.length > 1) {
|
||||||
numThreads = Integer.parseInt(args[1]);
|
numThreads = Integer.parseInt(args[1]);
|
||||||
}
|
}
|
||||||
@ -363,9 +342,12 @@ public class CalculateAverage_bufistov {
|
|||||||
|
|
||||||
long startLocation = 0;
|
long startLocation = 0;
|
||||||
ArrayList<Future<HashMap<ByteArrayWrapper, ResultRow>>> results = new ArrayList<>(numThreads);
|
ArrayList<Future<HashMap<ByteArrayWrapper, ResultRow>>> results = new ArrayList<>(numThreads);
|
||||||
|
var fileChannel = FileChannel.open(Paths.get(fileName));
|
||||||
|
boolean firstSegment = true;
|
||||||
while (remaining_size > 0) {
|
while (remaining_size > 0) {
|
||||||
long actualSize = Math.min(chunk_size, remaining_size);
|
long actualSize = Math.min(chunk_size, remaining_size);
|
||||||
results.add(executor.submit(new FileRead(startLocation, toIntExact(actualSize), channel)));
|
results.add(executor.submit(new FileRead(fileChannel, startLocation, toIntExact(actualSize), firstSegment)));
|
||||||
|
firstSegment = false;
|
||||||
remaining_size -= actualSize;
|
remaining_size -= actualSize;
|
||||||
startLocation += actualSize;
|
startLocation += actualSize;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user