Introducing the vector api. 1s faster on 4 core i7 (#506)
Co-authored-by: Ian Preston <ianopolous@protonmail.com>
This commit is contained in:
		| @@ -15,5 +15,5 @@ | ||||
| #  limitations under the License. | ||||
| # | ||||
|  | ||||
| JAVA_OPTS="--enable-preview" | ||||
| JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector" | ||||
| java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ianopolousfast | ||||
|   | ||||
| @@ -15,6 +15,10 @@ | ||||
|  */ | ||||
| package dev.morling.onebrc; | ||||
|  | ||||
| import jdk.incubator.vector.ByteVector; | ||||
| import jdk.incubator.vector.VectorOperators; | ||||
| import jdk.incubator.vector.VectorSpecies; | ||||
|  | ||||
| import java.lang.foreign.Arena; | ||||
| import java.lang.foreign.MemorySegment; | ||||
| import java.nio.ByteOrder; | ||||
| @@ -30,19 +34,23 @@ import static java.lang.foreign.ValueLayout.*; | ||||
| /* A fast implementation with no unsafe. | ||||
|  * Features: | ||||
|  * * memory mapped file using preview Arena FFI | ||||
|  * * semicolon finding using incubator vector api | ||||
|  * * read chunks in parallel | ||||
|  * * minimise allocation | ||||
|  * * no unsafe | ||||
|  * | ||||
|  * Timings on 4 core i7-7500U CPU @ 2.70GHz: | ||||
|  * average_baseline: 4m48s | ||||
|  * ianopolous:         16s | ||||
|  * ianopolous:         15s | ||||
| */ | ||||
| public class CalculateAverage_ianopolousfast { | ||||
|  | ||||
|     public static final int MAX_LINE_LENGTH = 107; | ||||
|     public static final int MAX_STATIONS = 1 << 14; | ||||
|     private static final OfLong LONG_LAYOUT = JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN); | ||||
|     private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED.length() >= 32 | ||||
|             ? ByteVector.SPECIES_256 | ||||
|             : ByteVector.SPECIES_128; | ||||
|  | ||||
|     public static void main(String[] args) throws Exception { | ||||
|         Arena arena = Arena.global(); | ||||
| @@ -165,58 +173,40 @@ public class CalculateAverage_ianopolousfast { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     public static long hasSemicolon(long d) { | ||||
|         // from Hacker's Delight page 92 | ||||
|         d = d ^ 0x3b3b3b3b3b3b3b3bL; | ||||
|         long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL; | ||||
|         return ~(y | d | 0x7f7f7f7f7f7f7f7fL); | ||||
|     } | ||||
|  | ||||
|     public static int getSemicolonIndex(long y) { | ||||
|         // from Hacker's Delight page 92 | ||||
|         return Long.numberOfLeadingZeros(y) >> 3; | ||||
|     } | ||||
|  | ||||
|     static long maskHighBytes(long d, int nbytes) { | ||||
|         return d & (-1L << ((8 - nbytes) * 8)); | ||||
|     } | ||||
|  | ||||
|     public static Stat parseStation(long lineStart, MemorySegment buffer, List<List<Stat>> stations) { | ||||
|         // find semicolon and update hash as we go, reading a long at a time | ||||
|         long d = buffer.get(LONG_LAYOUT, lineStart); | ||||
|         long hasSemi = hasSemicolon(d); | ||||
|         if (hasSemi != 0) { | ||||
|             int semiIndex = getSemicolonIndex(hasSemi); | ||||
|             d = maskHighBytes(d, semiIndex); | ||||
|             return dedupeStation8(lineStart, lineStart + semiIndex, d, d, buffer, stations); | ||||
|         } | ||||
|         long first8 = d; | ||||
|         long hash = d; | ||||
|         ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder()); | ||||
|         int keySize = line.compare(VectorOperators.EQ, ';').firstTrue(); | ||||
|  | ||||
|         d = buffer.get(LONG_LAYOUT, lineStart + 8); | ||||
|         hasSemi = hasSemicolon(d); | ||||
|         if (hasSemi != 0) { | ||||
|             int semiIndex = getSemicolonIndex(hasSemi); | ||||
|             if (semiIndex == 0) | ||||
|                 return dedupeStation8(lineStart, lineStart + 8, first8, first8, buffer, stations); | ||||
|             d = maskHighBytes(d, semiIndex); | ||||
|             return dedupeStation16(lineStart, lineStart + 8 + semiIndex, first8 ^ d, first8, d, buffer, stations); | ||||
|         if (keySize == BYTE_SPECIES.vectorByteSize()) { | ||||
|             while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') { | ||||
|                 keySize++; | ||||
|             } | ||||
|             long first8 = buffer.get(LONG_LAYOUT, lineStart); | ||||
|             if (keySize < 8) | ||||
|                 return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations); | ||||
|             long second8 = buffer.get(LONG_LAYOUT, lineStart + 8); | ||||
|             if (keySize < 16) | ||||
|                 return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations); | ||||
|             long hash = first8 ^ second8; // todo include other bytes | ||||
|             return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations); | ||||
|         } | ||||
|  | ||||
|         int index = 8; | ||||
|         long second8 = d; | ||||
|         while (hasSemi == 0) { | ||||
|             hash = hash ^ d; | ||||
|             index += 8; | ||||
|             d = buffer.get(LONG_LAYOUT, lineStart + index); | ||||
|             hasSemi = hasSemicolon(d); | ||||
|         long first8 = buffer.get(LONG_LAYOUT, lineStart); | ||||
|         if (keySize <= 8) { | ||||
|             first8 = maskHighBytes(first8, keySize & 0x07); | ||||
|             return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations); | ||||
|         } | ||||
|         int semiIndex = getSemicolonIndex(hasSemi); | ||||
|         d = maskHighBytes(d, semiIndex); | ||||
|         if (semiIndex > 0) { | ||||
|             hash = hash ^ d; | ||||
|         long second8 = buffer.get(LONG_LAYOUT, lineStart + 8); | ||||
|         if (keySize < 16) { | ||||
|             second8 = maskHighBytes(second8, keySize & 0x07); | ||||
|             return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations); | ||||
|         } | ||||
|         return dedupeStation(lineStart, lineStart + index + semiIndex, hash, first8, second8, buffer, stations); | ||||
|         long hash = first8 ^ second8; // todo include later bytes | ||||
|         return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations); | ||||
|     } | ||||
|  | ||||
|     public static int getDot(long d) { | ||||
| @@ -266,24 +256,30 @@ public class CalculateAverage_ianopolousfast { | ||||
|         for (int i = 0; i < MAX_STATIONS; i++) | ||||
|             stations.add(null); | ||||
|  | ||||
|         // Handle reading the very last line in the file | ||||
|         // this allows us to not worry about reading a long beyond the end | ||||
|         // Handle reading the very last few lines in the file | ||||
|         // this allows us to not worry about reading beyond the end | ||||
|         // in the inner loop (reducing branches) | ||||
|         // We only need to read one because the min record size is 6 bytes | ||||
|         // so 2nd last record must be > 8 from end | ||||
|         // We need at least the vector lane size bytes back | ||||
|         if (endByte == buffer.byteSize()) { | ||||
|             endByte -= 2; // skip final new line | ||||
|             endByte -= 1; // skip final new line | ||||
|             // reverse at least vector lane width | ||||
|             while (endByte > 0 && buffer.byteSize() - endByte < BYTE_SPECIES.vectorByteSize()) { | ||||
|                 endByte--; | ||||
|                 while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n') | ||||
|                     endByte--; | ||||
|             } | ||||
|  | ||||
|             if (endByte > 0) | ||||
|                 endByte++; | ||||
|             // copy into a 8n sized buffer to avoid reading off end | ||||
|             MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 4); | ||||
|             // copy into a larger buffer to avoid reading off end | ||||
|             MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + BYTE_SPECIES.vectorByteSize()); | ||||
|             for (long i = endByte; i < buffer.byteSize(); i++) | ||||
|                 end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i)); | ||||
|             Stat station = parseStation(0, end, stations); | ||||
|             processTemperature(station.name.length + 1, end, station); | ||||
|             int index = 0; | ||||
|             while (endByte + index < buffer.byteSize()) { | ||||
|                 Stat station = parseStation(index, end, stations); | ||||
|                 index = (int) processTemperature(index + station.name.length + 1, end, station); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         while (startByte < endByte) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user