improved 2nd and final submission (#685)
This commit is contained in:
		| @@ -1,4 +1,4 @@ | |||||||
| #!/bin/sh | #!/bin/bash | ||||||
| # | # | ||||||
| #  Copyright 2023 The original authors | #  Copyright 2023 The original authors | ||||||
| # | # | ||||||
| @@ -19,5 +19,8 @@ | |||||||
| # source "$HOME/.sdkman/bin/sdkman-init.sh" | # source "$HOME/.sdkman/bin/sdkman-init.sh" | ||||||
| # sdk use java 21.0.1-graal 1>&2 | # sdk use java 21.0.1-graal 1>&2 | ||||||
|  |  | ||||||
| JAVA_OPTS="--enable-preview --enable-native-access=ALL-UNNAMED --add-modules jdk.incubator.vector" | JAVA_OPTS="-Xlog:all=off -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 --enable-preview --enable-native-access=ALL-UNNAMED --add-modules jdk.incubator.vector" | ||||||
| java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yourwass |  | ||||||
|  | eval "exec 3< <({ java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yourwass; })" | ||||||
|  | read <&3 result | ||||||
|  | echo -e "$result" | ||||||
|   | |||||||
| @@ -16,6 +16,8 @@ | |||||||
| package dev.morling.onebrc; | package dev.morling.onebrc; | ||||||
|  |  | ||||||
| import java.util.TreeMap; | import java.util.TreeMap; | ||||||
|  | import java.util.concurrent.locks.Lock; | ||||||
|  | import java.util.concurrent.locks.ReentrantLock; | ||||||
| import java.io.IOException; | import java.io.IOException; | ||||||
| import java.lang.foreign.Arena; | import java.lang.foreign.Arena; | ||||||
| import java.lang.foreign.MemorySegment; | import java.lang.foreign.MemorySegment; | ||||||
| @@ -31,18 +33,15 @@ import jdk.incubator.vector.VectorSpecies; | |||||||
| import sun.misc.Unsafe; | import sun.misc.Unsafe; | ||||||
|  |  | ||||||
| public class CalculateAverage_yourwass { | public class CalculateAverage_yourwass { | ||||||
|  |  | ||||||
|     static final class Record { |     static final class Record { | ||||||
|         public String city; |         private long cityAddr; | ||||||
|         public long cityAddr; |         private long cityLength; | ||||||
|         public long cityLength; |         private int min; | ||||||
|         public int min; |         private int max; | ||||||
|         public int max; |         private int count; | ||||||
|         public int count; |         private long sum; | ||||||
|         public long sum; |  | ||||||
|  |  | ||||||
|         Record(final long cityAddr, final long cityLength) { |         Record(final long cityAddr, final long cityLength) { | ||||||
|             this.city = null; |  | ||||||
|             this.cityAddr = cityAddr; |             this.cityAddr = cityAddr; | ||||||
|             this.cityLength = cityLength; |             this.cityLength = cityLength; | ||||||
|             this.min = 1000; |             this.min = 1000; | ||||||
| @@ -62,6 +61,8 @@ public class CalculateAverage_yourwass { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     private final static Lock _mutex = new ReentrantLock(true); | ||||||
|  |     private final static TreeMap<String, Record> aggregateResults = new TreeMap<>(); | ||||||
|     private static short lookupDecimal[]; |     private static short lookupDecimal[]; | ||||||
|     private static byte lookupFraction[]; |     private static byte lookupFraction[]; | ||||||
|     private static byte lookupDotPositive[]; |     private static byte lookupDotPositive[]; | ||||||
| @@ -70,6 +71,8 @@ public class CalculateAverage_yourwass { | |||||||
|     private static final VectorSpecies<Byte> SPECIES = ByteVector.SPECIES_PREFERRED; |     private static final VectorSpecies<Byte> SPECIES = ByteVector.SPECIES_PREFERRED; | ||||||
|     private static final int MAXINDEX = (1 << 16) + 10000; // short hash + max allowed cities for collisions at the end :p |     private static final int MAXINDEX = (1 << 16) + 10000; // short hash + max allowed cities for collisions at the end :p | ||||||
|     private static final String FILE = "measurements.txt"; |     private static final String FILE = "measurements.txt"; | ||||||
|  |     private static long unsafeResults; | ||||||
|  |     private static int RECORDSIZE = 36; | ||||||
|     private static final Unsafe UNSAFE = getUnsafe(); |     private static final Unsafe UNSAFE = getUnsafe(); | ||||||
|  |  | ||||||
|     private static Unsafe getUnsafe() { |     private static Unsafe getUnsafe() { | ||||||
| @@ -113,11 +116,9 @@ public class CalculateAverage_yourwass { | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         // open file |         // open file | ||||||
|         final long fileSize, mmapAddr; |         final FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ); | ||||||
|         try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { |         final long fileSize = fileChannel.size(); | ||||||
|             fileSize = fileChannel.size(); |         final long mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); | ||||||
|             mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); |  | ||||||
|         } |  | ||||||
|         // VAS: Virtual Address Space, as a MemorySegment up to and including the mmapped file. |         // VAS: Virtual Address Space, as a MemorySegment up to and including the mmapped file. | ||||||
|         // If the mmapped MemorySegment is used for Vector creation as is, then there are two problems: |         // If the mmapped MemorySegment is used for Vector creation as is, then there are two problems: | ||||||
|         // 1) fromMemorySegment takes an offset and not an address, so we have to do arithmetic |         // 1) fromMemorySegment takes an offset and not an address, so we have to do arithmetic | ||||||
| @@ -127,36 +128,24 @@ public class CalculateAverage_yourwass { | |||||||
|         // XXX an out-of-bounds read is possible at the end of the file; this is not handled here. |         // XXX an out-of-bounds read is possible at the end of the file; this is not handled here. | ||||||
|         VAS = MemorySegment.ofAddress(0).reinterpret(mmapAddr + fileSize + SPECIES.length()); |         VAS = MemorySegment.ofAddress(0).reinterpret(mmapAddr + fileSize + SPECIES.length()); | ||||||
|  |  | ||||||
|         // start and wait for threads to finish |         // allocate memory for results | ||||||
|         final int nThreads = Runtime.getRuntime().availableProcessors(); |         final int nThreads = Runtime.getRuntime().availableProcessors(); | ||||||
|  |         unsafeResults = UNSAFE.allocateMemory(RECORDSIZE * MAXINDEX * nThreads); | ||||||
|  |         UNSAFE.setMemory(unsafeResults, RECORDSIZE * MAXINDEX * nThreads, (byte) 0); | ||||||
|  |  | ||||||
|  |         // start and wait for threads to finish | ||||||
|         Thread[] threadList = new Thread[nThreads]; |         Thread[] threadList = new Thread[nThreads]; | ||||||
|         final Record[][] results = new Record[nThreads][]; |  | ||||||
|         final long chunkSize = fileSize / nThreads; |         final long chunkSize = fileSize / nThreads; | ||||||
|         for (int i = 0; i < nThreads; i++) { |         for (int i = 0; i < nThreads; i++) { | ||||||
|             final int threadIndex = i; |             final int threadIndex = i; | ||||||
|             final long startAddr = mmapAddr + i * chunkSize; |             final long startAddr = mmapAddr + i * chunkSize; | ||||||
|             final long endAddr = (i == nThreads - 1) ? mmapAddr + fileSize : mmapAddr + (i + 1) * chunkSize; |             final long endAddr = (i == nThreads - 1) ? mmapAddr + fileSize : mmapAddr + (i + 1) * chunkSize; | ||||||
|             threadList[i] = new Thread(() -> results[threadIndex] = threadMain(threadIndex, startAddr, endAddr, nThreads)); |             threadList[i] = new Thread(() -> threadMain(threadIndex, startAddr, endAddr, nThreads)); | ||||||
|             threadList[i].start(); |             threadList[i].start(); | ||||||
|         } |         } | ||||||
|         for (int i = 0; i < nThreads; i++) |         for (int i = 0; i < nThreads; i++) | ||||||
|             threadList[i].join(); |             threadList[i].join(); | ||||||
|  |  | ||||||
|         // aggregate results and sort |  | ||||||
|         // TODO have to compare with concurrent-parallel stream structures: |  | ||||||
|         // * concurrent hashtable that have to sort afterwards |  | ||||||
|         // * concurrent skiplist that is sorted but has O(n) insert |  | ||||||
|         // * ..other? |  | ||||||
|         final TreeMap<String, Record> aggregateResults = new TreeMap<>(); |  | ||||||
|         for (int thread = 0; thread < nThreads; thread++) { |  | ||||||
|             for (int index = 0; index < MAXINDEX; index++) { |  | ||||||
|                 Record record = results[thread][index]; |  | ||||||
|                 if (record == null) |  | ||||||
|                     continue; |  | ||||||
|                 aggregateResults.compute(record.city, (k, v) -> (v == null) ? record : v.merge(record)); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         // prepare string and print |         // prepare string and print | ||||||
|         StringBuilder sb = new StringBuilder(); |         StringBuilder sb = new StringBuilder(); | ||||||
|         sb.append("{"); |         sb.append("{"); | ||||||
| @@ -167,12 +156,13 @@ public class CalculateAverage_yourwass { | |||||||
|             float max = record.max; |             float max = record.max; | ||||||
|             max /= 10.f; |             max /= 10.f; | ||||||
|             double avg = Math.round((record.sum * 1.0) / record.count) / 10.; |             double avg = Math.round((record.sum * 1.0) / record.count) / 10.; | ||||||
|             sb.append(record.city).append("=").append(min).append("/").append(avg).append("/").append(max).append(", "); |             sb.append(entry.getKey()).append("=").append(min).append("/").append(avg).append("/").append(max).append(", "); | ||||||
|         } |         } | ||||||
|         int stringLength = sb.length(); |         int stringLength = sb.length(); | ||||||
|         sb.setCharAt(stringLength - 2, '}'); |         sb.setCharAt(stringLength - 2, '}'); | ||||||
|         sb.setCharAt(stringLength - 1, '\n'); |         sb.setCharAt(stringLength - 1, '\n'); | ||||||
|         System.out.print(sb.toString()); |         System.out.print(sb.toString()); | ||||||
|  |         System.out.close(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     private static final boolean citiesDiffer(final long a, final long b, final long len) { |     private static final boolean citiesDiffer(final long a, final long b, final long len) { | ||||||
| @@ -185,7 +175,7 @@ public class CalculateAverage_yourwass { | |||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     private static Record[] threadMain(int id, long startAddr, long endAddr, long nThreads) { |     private static void threadMain(int id, long startAddr, long endAddr, long nThreads) { | ||||||
|         // snap to newlines |         // snap to newlines | ||||||
|         if (id != 0) |         if (id != 0) | ||||||
|             while (UNSAFE.getByte(startAddr++) != '\n') |             while (UNSAFE.getByte(startAddr++) != '\n') | ||||||
| @@ -194,23 +184,24 @@ public class CalculateAverage_yourwass { | |||||||
|             while (UNSAFE.getByte(endAddr++) != '\n') |             while (UNSAFE.getByte(endAddr++) != '\n') | ||||||
|                 ; |                 ; | ||||||
|  |  | ||||||
|  |         final long threadResults = unsafeResults + id * MAXINDEX * RECORDSIZE; | ||||||
|         final Record[] results = new Record[MAXINDEX]; |         final Record[] results = new Record[MAXINDEX]; | ||||||
|         final long VECTORBYTESIZE = SPECIES.length(); |         final long VECTORBYTESIZE = SPECIES.length(); | ||||||
|         final ByteOrder BYTEORDER = ByteOrder.nativeOrder(); |         final ByteOrder BYTEORDER = ByteOrder.nativeOrder(); | ||||||
|         final ByteVector delim = ByteVector.broadcast(SPECIES, ';'); |         final ByteVector delim = ByteVector.broadcast(SPECIES, ';'); | ||||||
|         long nextCityAddr = startAddr; // XXX from these three variables, |         long cityAddr = startAddr; | ||||||
|         long cityAddr = nextCityAddr; // only two are necessary, but if one |         long ptr = 0; | ||||||
|         long ptr = 0; // is eliminated, on my pc the benchmark gets worse.. |         while (cityAddr < endAddr) { | ||||||
|         while (nextCityAddr < endAddr) { |  | ||||||
|             // parse city |             // parse city | ||||||
|             long mask = ByteVector.fromMemorySegment(SPECIES, VAS, nextCityAddr + ptr, BYTEORDER) |             ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER); | ||||||
|                     .compare(VectorOperators.EQ, delim).toLong(); |             long mask = parsed.compare(VectorOperators.EQ, delim).toLong(); | ||||||
|             if (mask == 0) { |             while (mask == 0) { | ||||||
|                 ptr += VECTORBYTESIZE; |                 ptr += VECTORBYTESIZE; | ||||||
|                 continue; |                 mask = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr + ptr, BYTEORDER).compare(VectorOperators.EQ, delim).toLong(); | ||||||
|             } |             } | ||||||
|             final long cityLength = ptr + Long.numberOfTrailingZeros(mask); |             final long cityLength = ptr + Long.numberOfTrailingZeros(mask); | ||||||
|             final long tempAddr = cityAddr + cityLength + 1; |             final long tempAddr = cityAddr + cityLength + 1; | ||||||
|  |             ptr = 0; | ||||||
|  |  | ||||||
|             // compute hash table index |             // compute hash table index | ||||||
|             int index; |             int index; | ||||||
| @@ -222,67 +213,79 @@ public class CalculateAverage_yourwass { | |||||||
|                         & 0xFFFF; |                         & 0xFFFF; | ||||||
|             else |             else | ||||||
|                 index = (UNSAFE.getByte(cityAddr) << 8) & 0xFF00; |                 index = (UNSAFE.getByte(cityAddr) << 8) & 0xFF00; | ||||||
|  |  | ||||||
|             // resolve collisions with linear probing |             // resolve collisions with linear probing | ||||||
|             // use vector api here also, but only if city name fits in one vector length, for faster default case |             // use vector api here also, but only if city name fits in one vector length, for faster default case | ||||||
|             Record record = results[index]; |             long record = threadResults + index * RECORDSIZE; | ||||||
|  |             long recordCityLength = UNSAFE.getLong(record); | ||||||
|             if (cityLength <= VECTORBYTESIZE) { |             if (cityLength <= VECTORBYTESIZE) { | ||||||
|                 ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER); |                 while (recordCityLength > 0) { | ||||||
|                 while (record != null) { |                     if (cityLength == recordCityLength) { | ||||||
|                     if (cityLength == record.cityLength) { |                         long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, UNSAFE.getLong(record + 8), BYTEORDER) | ||||||
|                         long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, record.cityAddr, BYTEORDER) |  | ||||||
|                                 .compare(VectorOperators.EQ, parsed).toLong(); |                                 .compare(VectorOperators.EQ, parsed).toLong(); | ||||||
|                         if (Long.numberOfTrailingZeros(~sameMask) >= cityLength) |                         if (Long.numberOfTrailingZeros(~sameMask) >= cityLength) | ||||||
|                             break; |                             break; | ||||||
|                     } |                     } | ||||||
|                     record = results[++index]; |                     index++; | ||||||
|  |                     record = threadResults + index * RECORDSIZE; | ||||||
|  |                     recordCityLength = UNSAFE.getLong(record); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             else { // slower normal case for city names with length > VECTORBYTESIZE |             else { // slower normal case for city names with length > VECTORBYTESIZE | ||||||
|                 while (record != null && (cityLength != record.cityLength || citiesDiffer(record.cityAddr, cityAddr, cityLength))) |                 while (recordCityLength > 0 && (cityLength != recordCityLength || citiesDiffer(UNSAFE.getLong(record + 8), cityAddr, cityLength))) { | ||||||
|                     record = results[++index]; |                     index++; | ||||||
|  |                     record = threadResults + index * RECORDSIZE; | ||||||
|  |                     recordCityLength = UNSAFE.getLong(record); | ||||||
|  |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             // add record for new keys |             // add record for new key | ||||||
|             // TODO have to avoid memory allocations on hot path |             if (recordCityLength == 0) { | ||||||
|             if (record == null) { |                 UNSAFE.putLong(record, cityLength); | ||||||
|                 results[index] = new Record(cityAddr, cityLength); |                 UNSAFE.putLong(record + 8, cityAddr); | ||||||
|                 record = results[index]; |                 UNSAFE.putInt(record + 16, 1000); | ||||||
|  |                 UNSAFE.putInt(record + 20, -1000); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             // parse temp with lookup tables |             // parse temp with lookup tables | ||||||
|             int temp; |             int temp; | ||||||
|             if (UNSAFE.getByte(tempAddr) == '-') { |             if (UNSAFE.getByte(tempAddr) == '-') { | ||||||
|                 temp = -lookupDecimal[UNSAFE.getShort(tempAddr + 1)] - lookupFraction[UNSAFE.getShort(tempAddr + 3)]; |                 temp = -lookupDecimal[UNSAFE.getShort(tempAddr + 1)] - lookupFraction[UNSAFE.getShort(tempAddr + 3)]; | ||||||
|                 nextCityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)]; |                 cityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)]; | ||||||
|             } |             } | ||||||
|             else { |             else { | ||||||
|                 temp = lookupDecimal[UNSAFE.getShort(tempAddr)] + lookupFraction[UNSAFE.getShort(tempAddr + 2)]; |                 temp = lookupDecimal[UNSAFE.getShort(tempAddr)] + lookupFraction[UNSAFE.getShort(tempAddr + 2)]; | ||||||
|                 nextCityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)]; |                 cityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)]; | ||||||
|             } |             } | ||||||
|             cityAddr = nextCityAddr; |  | ||||||
|             ptr = 0; |  | ||||||
|  |  | ||||||
|             // merge record |             // merge | ||||||
|             if (temp < record.min) |             if (temp < UNSAFE.getInt(record + 16)) | ||||||
|                 record.min = temp; |                 UNSAFE.putInt(record + 16, temp); | ||||||
|             if (temp > record.max) |             if (temp > UNSAFE.getInt(record + 20)) | ||||||
|                 record.max = temp; |                 UNSAFE.putInt(record + 20, temp); | ||||||
|             record.sum += temp; |             UNSAFE.putLong(record + 24, UNSAFE.getLong(record + 24) + temp); | ||||||
|             record.count += 1; |             UNSAFE.putInt(record + 32, UNSAFE.getInt(record + 32) + 1); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // create strings from raw data |         // create strings from raw data | ||||||
|         // TODO should avoid this copy |         // and aggregate results onto TreeMap | ||||||
|  |         int idx = 0; | ||||||
|         byte b[] = new byte[100]; |         byte b[] = new byte[100]; | ||||||
|  |         _mutex.lock(); | ||||||
|         for (int i = 0; i < MAXINDEX; i++) { |         for (int i = 0; i < MAXINDEX; i++) { | ||||||
|             Record r = results[i]; |             if (UNSAFE.getLong(threadResults + i * RECORDSIZE) == 0) | ||||||
|             if (r == null) |  | ||||||
|                 continue; |                 continue; | ||||||
|             UNSAFE.copyMemory(null, r.cityAddr, b, Unsafe.ARRAY_BYTE_BASE_OFFSET, r.cityLength); |             final long recordAddress = threadResults + i * RECORDSIZE; | ||||||
|             r.city = new String(b, 0, (int) r.cityLength, StandardCharsets.UTF_8); |  | ||||||
|         } |  | ||||||
|         return results; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|  |             results[idx] = new Record(UNSAFE.getLong(recordAddress + 8), UNSAFE.getLong(recordAddress)); | ||||||
|  |             results[idx].min = UNSAFE.getInt(recordAddress + 16); | ||||||
|  |             results[idx].max = UNSAFE.getInt(recordAddress + 20); | ||||||
|  |             results[idx].sum = UNSAFE.getLong(recordAddress + 24); | ||||||
|  |             results[idx].count = UNSAFE.getInt(recordAddress + 32); | ||||||
|  |             UNSAFE.copyMemory(null, UNSAFE.getLong(recordAddress + 8), b, Unsafe.ARRAY_BYTE_BASE_OFFSET, UNSAFE.getLong(recordAddress)); | ||||||
|  |             final Record record = results[idx]; | ||||||
|  |             aggregateResults.compute(new String(b, 0, (int) results[idx].cityLength, StandardCharsets.UTF_8), (k, v) -> (v == null) ? record : v.merge(record)); | ||||||
|  |             idx++; | ||||||
|  |         } | ||||||
|  |         _mutex.unlock(); | ||||||
|  |     } | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user