Process two consecutive lines at a time (#651)
Use a better hash function Don't return index from temperature parsing extra JVM args Co-authored-by: Ian Preston <ianopolous@protonmail.com>
This commit is contained in:
		| @@ -16,4 +16,6 @@ | ||||
| # | ||||
|  | ||||
| JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector" | ||||
| #-Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 -XX:-UseTransparentHugePages" | ||||
|  | ||||
| java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ianopolousfast | ||||
|   | ||||
| @@ -19,7 +19,6 @@ import jdk.incubator.vector.ByteVector; | ||||
| import jdk.incubator.vector.VectorOperators; | ||||
| import jdk.incubator.vector.VectorSpecies; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.lang.foreign.Arena; | ||||
| import java.lang.foreign.MemorySegment; | ||||
| import java.nio.ByteOrder; | ||||
| @@ -39,10 +38,7 @@ import static java.lang.foreign.ValueLayout.*; | ||||
|  * * read chunks in parallel | ||||
|  * * minimise allocation | ||||
|  * * no unsafe | ||||
|  * | ||||
|  * Timings on 4 core i7-7500U CPU @ 2.70GHz: | ||||
|  * average_baseline: 4m48s | ||||
|  * ianopolous:         13.8s | ||||
|  * * process multiple lines in each thread for better ILP | ||||
| */ | ||||
| public class CalculateAverage_ianopolousfast { | ||||
|  | ||||
| @@ -91,11 +87,22 @@ public class CalculateAverage_ianopolousfast { | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     private static int hashToIndex(long hash, int len) { | ||||
|         // From Thomas Wuerthinger's entry | ||||
|         int hashAsInt = (int) (hash ^ (hash >>> 28)); | ||||
|         int finalHash = (hashAsInt ^ (hashAsInt >>> 15)); | ||||
|         return (finalHash & (len - 1)); | ||||
|     private static final int GOLDEN_RATIO = 0x9E3779B9; | ||||
|     private static final int HASH_LROTATE = 5; | ||||
|  | ||||
|     // hash from giovannicuccu | ||||
|     private static int hash(MemorySegment memorySegment, long start, int len) { | ||||
|         int x; | ||||
|         int y; | ||||
|         if (len >= Integer.BYTES) { | ||||
|             x = memorySegment.get(JAVA_INT_UNALIGNED, start); | ||||
|             y = memorySegment.get(JAVA_INT_UNALIGNED, start + len - Integer.BYTES); | ||||
|         } | ||||
|         else { | ||||
|             x = memorySegment.get(JAVA_BYTE, start); | ||||
|             y = memorySegment.get(JAVA_BYTE, start + len - Byte.BYTES); | ||||
|         } | ||||
|         return (Integer.rotateLeft(x * GOLDEN_RATIO, HASH_LROTATE) ^ y) * GOLDEN_RATIO; | ||||
|     } | ||||
|  | ||||
|     public static Stat createStation(long start, long end, MemorySegment buffer) { | ||||
| @@ -105,8 +112,9 @@ public class CalculateAverage_ianopolousfast { | ||||
|         return new Stat(stationBuffer); | ||||
|     } | ||||
|  | ||||
|     public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, Stat[] stations) { | ||||
|         int index = hashToIndex(hash, MAX_STATIONS); | ||||
|     public static Stat dedupeStation(long start, long end, MemorySegment buffer, Stat[] stations) { | ||||
|         int hash = hash(buffer, start, (int) (end - start)); | ||||
|         int index = hash & (MAX_STATIONS - 1); | ||||
|         Stat match = stations[index]; | ||||
|         while (match != null) { | ||||
|             if (matchingStationBytes(start, end, buffer, match)) | ||||
| @@ -119,37 +127,11 @@ public class CalculateAverage_ianopolousfast { | ||||
|         return res; | ||||
|     } | ||||
|  | ||||
|     static long maskHighBytes(long d, int nbytes) { | ||||
|         return d & (-1L << ((8 - nbytes) * 8)); | ||||
|     } | ||||
|  | ||||
|     public static Stat parseStation(long lineStart, MemorySegment buffer, Stat[] stations) { | ||||
|         ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder()); | ||||
|         int keySize = line.compare(VectorOperators.EQ, ';').firstTrue(); | ||||
|  | ||||
|         long first8 = buffer.get(LONG_LAYOUT, lineStart); | ||||
|         long second8 = 0; | ||||
|         if (keySize <= 8) { | ||||
|             first8 = maskHighBytes(first8, keySize & 0x07); | ||||
|         } | ||||
|         else if (keySize < 16) { | ||||
|             second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); | ||||
|         } | ||||
|         else if (keySize == BYTE_SPECIES.vectorByteSize()) { | ||||
|             while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') { | ||||
|                 keySize++; | ||||
|             } | ||||
|             second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); | ||||
|         } | ||||
|         long hash = first8 ^ second8; // todo include later bytes | ||||
|         return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations); | ||||
|     } | ||||
|  | ||||
|     public static short getMinus(long d) { | ||||
|         return ((d & 0xff00000000000000L) ^ 0x2d00000000000000L) != 0 ? 0 : (short) -1; | ||||
|     } | ||||
|  | ||||
|     public static long processTemperature(long lineSplit, int size, MemorySegment buffer, Stat station) { | ||||
|     public static void processTemperature(long lineSplit, int size, MemorySegment buffer, Stat station) { | ||||
|         long d = buffer.get(LONG_LAYOUT, lineSplit); | ||||
|         // negative is either 0 or -1 | ||||
|         short negative = getMinus(d); | ||||
| @@ -162,10 +144,9 @@ public class CalculateAverage_ianopolousfast { | ||||
|                 100 * (((byte) (d >> 24)) - '0')); | ||||
|         temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merkitty | ||||
|         station.add(temperature); | ||||
|         return lineSplit + size + 1; | ||||
|     } | ||||
|  | ||||
|     private static long parseLine(long lineStart, MemorySegment buffer, Stat[] stations) { | ||||
|     private static int lineSize(long lineStart, MemorySegment buffer) { | ||||
|         ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder()); | ||||
|         int lineSize = line.compare(VectorOperators.EQ, '\n').firstTrue(); | ||||
|         int index = lineSize; | ||||
| @@ -174,33 +155,19 @@ public class CalculateAverage_ianopolousfast { | ||||
|                     ByteOrder.nativeOrder()).compare(VectorOperators.EQ, '\n').firstTrue(); | ||||
|             lineSize += index; | ||||
|         } | ||||
|         int keySize = lineSize - 6 + ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize - 6, | ||||
|                 ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue(); | ||||
|  | ||||
|         long first8 = buffer.get(LONG_LAYOUT, lineStart); | ||||
|         long second8 = 0; | ||||
|         if (keySize <= 8) { | ||||
|             first8 = maskHighBytes(first8, keySize & 0x07); | ||||
|         } | ||||
|         else if (keySize < 16) { | ||||
|             second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); | ||||
|         } | ||||
|         else if (keySize == BYTE_SPECIES.vectorByteSize()) { | ||||
|             while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') { | ||||
|                 keySize++; | ||||
|             } | ||||
|             second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); | ||||
|         } | ||||
|         long hash = first8 ^ second8; // todo include later bytes | ||||
|         Stat station = dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations); | ||||
|         return processTemperature(lineStart + keySize + 1, lineSize - keySize - 1, buffer, station); | ||||
|         return lineSize; | ||||
|     } | ||||
|  | ||||
|     public static Stat[] parseStats(long startByte, long endByte, MemorySegment buffer) { | ||||
|     private static int keySize(int lineSize, long lineStart, MemorySegment buffer) { | ||||
|         return lineSize - 6 + ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize - 6, | ||||
|                 ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue(); | ||||
|     } | ||||
|  | ||||
|     public static Stat[] parseStats(long start1, long end2, MemorySegment buffer) { | ||||
|         // read first partial line | ||||
|         if (startByte > 0) { | ||||
|         if (start1 > 0) { | ||||
|             for (int i = 0; i < MAX_LINE_LENGTH; i++) { | ||||
|                 byte b = buffer.get(JAVA_BYTE, startByte++); | ||||
|                 byte b = buffer.get(JAVA_BYTE, start1++); | ||||
|                 if (b == '\n') { | ||||
|                     break; | ||||
|                 } | ||||
| @@ -213,38 +180,47 @@ public class CalculateAverage_ianopolousfast { | ||||
|         // this allows us to not worry about reading beyond the end | ||||
|         // in the inner loop (reducing branches) | ||||
|         // We need at least the vector lane size bytes back | ||||
|         if (endByte == buffer.byteSize()) { | ||||
|         if (end2 == buffer.byteSize()) { | ||||
|             // reverse at least vector lane width | ||||
|             endByte = Math.max(buffer.byteSize() - BYTE_SPECIES.vectorByteSize(), 0); | ||||
|             while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n') | ||||
|                 endByte--; | ||||
|             end2 = Math.max(buffer.byteSize() - 2 * BYTE_SPECIES.vectorByteSize(), 0); | ||||
|             while (end2 > 0 && buffer.get(JAVA_BYTE, end2) != '\n') | ||||
|                 end2--; | ||||
|  | ||||
|             if (endByte > 0) | ||||
|                 endByte++; | ||||
|             if (end2 > 0) | ||||
|                 end2++; | ||||
|             // copy into a larger buffer to avoid reading off end | ||||
|             MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + BYTE_SPECIES.vectorByteSize()); | ||||
|             for (long i = endByte; i < buffer.byteSize(); i++) | ||||
|                 end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i)); | ||||
|             MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 2 * BYTE_SPECIES.vectorByteSize()); | ||||
|             for (long i = end2; i < buffer.byteSize(); i++) | ||||
|                 end.set(JAVA_BYTE, i - end2, buffer.get(JAVA_BYTE, i)); | ||||
|             int index = 0; | ||||
|             while (endByte + index < buffer.byteSize()) { | ||||
|                 Stat station = parseStation(index, end, stations); | ||||
|                 int tempSize = 3; | ||||
|                 if (end.get(JAVA_BYTE, index + station.namelen + 5) == '\n') | ||||
|                     tempSize = 4; | ||||
|                 if (end.get(JAVA_BYTE, index + station.namelen + 6) == '\n') | ||||
|                     tempSize = 5; | ||||
|                 index = (int) processTemperature(index + station.namelen + 1, tempSize, end, station); | ||||
|             while (end2 + index < buffer.byteSize()) { | ||||
|                 int lineSize1 = lineSize(index, end); | ||||
|                 int semiSearchStart = index + Math.max(0, lineSize1 - 6); | ||||
|                 int keySize1 = semiSearchStart - index + ByteVector.fromMemorySegment(BYTE_SPECIES, end, semiSearchStart, | ||||
|                         ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue(); | ||||
|                 Stat station1 = dedupeStation(index, index + keySize1, end, stations); | ||||
|                 processTemperature(index + keySize1 + 1, lineSize1 - keySize1 - 1, end, station1); | ||||
|                 index += lineSize1 + 1; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         innerloop(startByte, endByte, buffer, stations); | ||||
|         return stations; | ||||
|     } | ||||
|  | ||||
|     private static void innerloop(long startByte, long endByte, MemorySegment buffer, Stat[] stations) { | ||||
|         while (startByte < endByte) { | ||||
|             startByte = parseLine(startByte, buffer, stations); | ||||
|         while (start1 < end2) { | ||||
|             int lineSize1 = lineSize(start1, buffer); | ||||
|             long start2 = start1 + lineSize1 + 1; | ||||
|             int lineSize2 = start2 < end2 ? lineSize(start2, buffer) : 0; | ||||
|             int keySize1 = keySize(lineSize1, start1, buffer); | ||||
|             int keySize2 = keySize(lineSize2, start2, buffer); | ||||
|             Stat station1 = dedupeStation(start1, start1 + keySize1, buffer, stations); | ||||
|             processTemperature(start1 + keySize1 + 1, lineSize1 - keySize1 - 1, buffer, station1); | ||||
|             if (start2 < end2) { | ||||
|                 Stat station2 = dedupeStation(start2, start2 + keySize2, buffer, stations); | ||||
|                 processTemperature(start2 + keySize2 + 1, lineSize2 - keySize2 - 1, buffer, station2); | ||||
|                 start1 = start2 + lineSize2 + 1; | ||||
|             } | ||||
|             else | ||||
|                 start1 += lineSize1 + 1; | ||||
|         } | ||||
|         return stations; | ||||
|     } | ||||
|  | ||||
|     public static class Stat { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user