| @@ -15,12 +15,16 @@ | ||||
|  */ | ||||
| package dev.morling.onebrc; | ||||
|  | ||||
| import jdk.incubator.vector.ByteVector; | ||||
| import jdk.incubator.vector.VectorOperators; | ||||
| import jdk.incubator.vector.VectorSpecies; | ||||
| import sun.misc.Unsafe; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.lang.foreign.Arena; | ||||
| import java.lang.foreign.MemorySegment; | ||||
| import java.lang.reflect.Field; | ||||
| import java.nio.ByteOrder; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.nio.file.Files; | ||||
| @@ -36,6 +40,7 @@ import java.util.concurrent.ConcurrentHashMap; | ||||
|  * Parse key as byte vs string      30000 ms | ||||
|  * Parse temp as fixed vs double    15000 ms | ||||
|  * HashMap optimization             10000 ms | ||||
|  * Simd + reduce memory copy         8000 ms | ||||
|  * | ||||
|  */ | ||||
| public class CalculateAverage_jincongho { | ||||
| @@ -55,6 +60,115 @@ public class CalculateAverage_jincongho { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Vectorization utilities with 1BRC-specific optimizations. | ||||
|      * | ||||
|      * <p>All offsets taken by these helpers are byte offsets relative to the start of the | ||||
|      * given {@link MemorySegment}. | ||||
|      */ | ||||
|     protected static class VectorUtils { | ||||
|  | ||||
|         // key length is usually less than 32 bytes, having more is just expensive | ||||
|         public static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_256; | ||||
|  | ||||
|         /** Vectorized field delimiter search **/ | ||||
|  | ||||
|         /** | ||||
|          * Loads one {@link #BYTE_SPECIES}-wide vector from {@code data} at {@code offset} and | ||||
|          * looks for the first {@code ';'} in it. | ||||
|          * | ||||
|          * <p>The load is unmasked, so a full vector must be readable at {@code offset}. | ||||
|          * | ||||
|          * @param data   segment to scan | ||||
|          * @param offset byte offset into {@code data} where the vector load starts | ||||
|          * @return lane index of the first {@code ';'}, or {@code BYTE_SPECIES.length()} when | ||||
|          *         the loaded vector contains none | ||||
|          */ | ||||
|         public static int findDelimiter(MemorySegment data, long offset) { | ||||
|             return ByteVector.fromMemorySegment(VectorUtils.BYTE_SPECIES, data, offset, ByteOrder.nativeOrder()) | ||||
|                     .compare(VectorOperators.EQ, ';') | ||||
|                     .firstTrue(); | ||||
|         } | ||||
|  | ||||
|         /** Vectorized Hashing (explicit vectorization seems slower, overkill?) **/ | ||||
|  | ||||
|         // private static int[] HASH_ARRAY = initHashArray(); | ||||
|         // private static final IntVector HASH_VECTOR = IntVector.fromArray(IntVector.SPECIES_256, HASH_ARRAY, 0); | ||||
|         // private static final int HASH_ACCUM = HASH_ARRAY[0] * 31; | ||||
|         // | ||||
|         // private static int[] initHashArray() { | ||||
|         // int[] x = new int[IntVector.SPECIES_256.length()]; | ||||
|         // x[x.length - 1] = 1; | ||||
|         // for (int i = x.length - 2; i >= 0; i--) | ||||
|         // x[i] = x[i + 1] * 31; | ||||
|         // | ||||
|         // return x; | ||||
|         // } | ||||
|  | ||||
|         /** | ||||
|          * Ref: https://github.com/PaulSandoz/vector-api-dev-live-10-2021/blob/main/src/main/java/jmh/BytesHashcode.java | ||||
|          * | ||||
|          * Essentially we are doing this calculation: | ||||
|          * h = h * 31 * 31 * 31 * 31 * 31 * 31 * 31 * 31 + | ||||
|          *         a[i + 0] * 31 * 31 * 31 * 31 * 31 * 31 * 31 + | ||||
|          *         a[i + 1] * 31 * 31 * 31 * 31 * 31 * 31 + | ||||
|          *         a[i + 2] * 31 * 31 * 31 * 31 * 31 + | ||||
|          *         a[i + 3] * 31 * 31 * 31 * 31 + | ||||
|          *         a[i + 4] * 31 * 31 * 31 + | ||||
|          *         a[i + 5] * 31 * 31 + | ||||
|          *         a[i + 6] * 31 + | ||||
|          *         a[i + 7]; | ||||
|          */ | ||||
|         // public static int hashCode(MemorySegment array, long offset, short length) { | ||||
|         // int h = 1; | ||||
|         // long i = offset, loopBound = offset + ByteVector.SPECIES_64.loopBound(length), tailBound = offset + length; | ||||
|         // for (; i < loopBound; i += ByteVector.SPECIES_64.length()) { | ||||
|         // // load 8 bytes, into a 64-bit vector | ||||
|         // ByteVector b = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, array, i, ByteOrder.nativeOrder()); | ||||
|         // // convert 8 bytes into 8 ints (hashing calculation needs int!) | ||||
|         // IntVector x = (IntVector) b.castShape(IntVector.SPECIES_256, 0); | ||||
|         // h = h * HASH_ACCUM + x.mul(HASH_VECTOR).reduceLanes(VectorOperators.ADD); | ||||
|         // } | ||||
|         // | ||||
|         // for (; i < tailBound; i++) { | ||||
|         // h = 31 * h + array.get(ValueLayout.JAVA_BYTE, i); | ||||
|         // } | ||||
|         // return h; | ||||
|         // } | ||||
|  | ||||
|         /** | ||||
|          * Scalar hash over {@code length} bytes using the recurrence {@code h = 31 * h + b}, | ||||
|          * seeded with 1 (the same recurrence as {@link java.util.Arrays#hashCode(byte[])}). | ||||
|          * | ||||
|          * <p>Native segments are read directly through {@code Unsafe} without bounds checks. | ||||
|          * Heap segments are read through the bounds-checked segment accessor, because their | ||||
|          * {@code address()} is not a machine address that {@code Unsafe} may dereference. | ||||
|          * | ||||
|          * @param array  segment holding the bytes to hash | ||||
|          * @param offset byte offset into {@code array} of the first byte | ||||
|          * @param length number of bytes to hash | ||||
|          * @return the hash of the byte range | ||||
|          */ | ||||
|         public static int hashCode(final MemorySegment array, final long offset, final short length) { | ||||
|             final long limit = offset + length; | ||||
|             int h = 1; | ||||
|             if (array.isNative()) { | ||||
|                 // hot path: off-heap memory, address() is a real pointer | ||||
|                 final long base = array.address(); | ||||
|                 for (long i = offset; i < limit; i++) { | ||||
|                     h = 31 * h + UNSAFE.getByte(base + i); | ||||
|                 } | ||||
|             } | ||||
|             else { | ||||
|                 // heap-backed segment (e.g. MemorySegment.ofArray): raw Unsafe reads would hit invalid memory | ||||
|                 for (long i = offset; i < limit; i++) { | ||||
|                     h = 31 * h + array.get(java.lang.foreign.ValueLayout.JAVA_BYTE, i); | ||||
|                 } | ||||
|             } | ||||
|             return h; | ||||
|         } | ||||
|  | ||||
|         /** Vectorized Key Comparison **/ | ||||
|  | ||||
|         /** | ||||
|          * Compares {@code length} bytes of {@code a} (from {@code aOffset}) with {@code length} | ||||
|          * bytes of {@code b} (from {@code bOffset}). | ||||
|          * | ||||
|          * <p>Whole vectors are compared with unmasked loads; the remaining bytes, if any, are | ||||
|          * loaded with a range mask so that nothing past the two limits takes part in the | ||||
|          * comparison. | ||||
|          * | ||||
|          * @param a            first segment | ||||
|          * @param aOffset      byte offset of the range in {@code a} | ||||
|          * @param b            second segment | ||||
|          * @param bOffset      byte offset of the range in {@code b} | ||||
|          * @param length       number of bytes to compare | ||||
|          * @param BYTE_SPECIES vector shape used for the comparison | ||||
|          * @return {@code true} if at least one byte differs, {@code false} if the ranges are equal | ||||
|          */ | ||||
|         private static boolean notEquals(MemorySegment a, long aOffset, MemorySegment b, long bOffset, short length, VectorSpecies<Byte> BYTE_SPECIES) { | ||||
|             final int lanes = BYTE_SPECIES.length(); | ||||
|             final long aLimit = aOffset + length, bLimit = bOffset + length; | ||||
|  | ||||
|             // main loop - full vectors only, so no mask is needed | ||||
|             final long loopBound = bOffset + BYTE_SPECIES.loopBound(length); | ||||
|             for (; bOffset < loopBound; aOffset += lanes, bOffset += lanes) { | ||||
|                 ByteVector av = ByteVector.fromMemorySegment(BYTE_SPECIES, a, aOffset, ByteOrder.nativeOrder()); | ||||
|                 ByteVector bv = ByteVector.fromMemorySegment(BYTE_SPECIES, b, bOffset, ByteOrder.nativeOrder()); | ||||
|                 if (av.compare(VectorOperators.NE, bv).anyTrue()) | ||||
|                     return true; | ||||
|             } | ||||
|  | ||||
|             // tail cleanup - load last N bytes with mask | ||||
|             if (bOffset < bLimit) { | ||||
|                 ByteVector av = ByteVector.fromMemorySegment(BYTE_SPECIES, a, aOffset, ByteOrder.nativeOrder(), BYTE_SPECIES.indexInRange(aOffset, aLimit)); | ||||
|                 ByteVector bv = ByteVector.fromMemorySegment(BYTE_SPECIES, b, bOffset, ByteOrder.nativeOrder(), BYTE_SPECIES.indexInRange(bOffset, bLimit)); | ||||
|                 if (av.compare(VectorOperators.NE, bv).anyTrue()) | ||||
|                     return true; | ||||
|             } | ||||
|  | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         // scalar implementation | ||||
|         // private static boolean equals(byte[] a, int aOffset, byte[] b, int bOffset, int len) { | ||||
|         // while (bOffset < len) | ||||
|         // if (a[aOffset++] != b[bOffset++]) | ||||
|         // return false; | ||||
|         // return true; | ||||
|         // } | ||||
|  | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Measurement Hash Table (for each partition) | ||||
|      * Uses contiguous byte array to optimize for cache-line (hopefully) | ||||
| @@ -70,26 +184,27 @@ public class CalculateAverage_jincongho { | ||||
|         private static int KEY_MASK = (MAP_SIZE - 1); | ||||
|         private static int VALUE_SIZE = 16; // min (2 bytes) + max ( 2 bytes) + count (4 bytes) + sum (8 bytes) | ||||
|  | ||||
|         private byte[] KEYS = new byte[MAP_SIZE * KEY_SIZE]; | ||||
|         private byte[] VALUES = new byte[MAP_SIZE * VALUE_SIZE]; | ||||
|         private MemorySegment KEYS = Arena.ofShared().allocate(MAP_SIZE * KEY_SIZE, 64); | ||||
|         private MemorySegment VALUES = Arena.ofShared().allocate(MAP_SIZE * VALUE_SIZE, 16); | ||||
|  | ||||
|         public PartitionAggr() { | ||||
|             // init min and max | ||||
|             for (int offset = UNSAFE.ARRAY_BYTE_BASE_OFFSET; offset < UNSAFE.ARRAY_BYTE_BASE_OFFSET + (MAP_SIZE * VALUE_SIZE); offset += VALUE_SIZE) { | ||||
|                 UNSAFE.putShort(VALUES, offset, Short.MAX_VALUE); | ||||
|                 UNSAFE.putShort(VALUES, offset + 2, Short.MIN_VALUE); | ||||
|             final long limit = VALUES.address() + (MAP_SIZE * VALUE_SIZE); | ||||
|             for (long offset = VALUES.address(); offset < limit; offset += VALUE_SIZE) { | ||||
|                 UNSAFE.putShort(offset, Short.MAX_VALUE); | ||||
|                 UNSAFE.putShort(offset + 2, Short.MIN_VALUE); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         public void update(byte[] key, int hash, short keyLength, short value) { | ||||
|             int index = hash & KEY_MASK; | ||||
|             int keyOffset = UNSAFE.ARRAY_BYTE_BASE_OFFSET + (index * KEY_SIZE); | ||||
|             while (((UNSAFE.getShort(KEYS, keyOffset) != keyLength) || | ||||
|                     !equals(KEYS, ((index * KEY_SIZE) + 2), key, 0, keyLength))) { | ||||
|                 if (UNSAFE.getShort(KEYS, keyOffset) == 0) { | ||||
|         public void update(MemorySegment key, long keyStart, short keyLength, int keyHash, short value) { | ||||
|             int index = keyHash & KEY_MASK; | ||||
|             long keyOffset = KEYS.address() + (index * KEY_SIZE); | ||||
|             while (((UNSAFE.getShort(keyOffset) != keyLength) || | ||||
|                     VectorUtils.notEquals(KEYS, ((index * KEY_SIZE) + 2), key, keyStart, keyLength, VectorUtils.BYTE_SPECIES))) { | ||||
|                 if (UNSAFE.getShort(keyOffset) == 0) { | ||||
|                     // put key | ||||
|                     UNSAFE.putShort(KEYS, keyOffset, keyLength); | ||||
|                     UNSAFE.copyMemory(key, UNSAFE.ARRAY_BYTE_BASE_OFFSET, KEYS, keyOffset + 2, keyLength); | ||||
|                     UNSAFE.putShort(keyOffset, keyLength); | ||||
|                     MemorySegment.copy(key, keyStart, KEYS, (index * KEY_SIZE) + 2, keyLength); | ||||
|                     break; | ||||
|                 } | ||||
|                 else { | ||||
| @@ -98,21 +213,14 @@ public class CalculateAverage_jincongho { | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             long valueOffset = UNSAFE.ARRAY_BYTE_BASE_OFFSET + (index * VALUE_SIZE); | ||||
|             UNSAFE.putShort(VALUES, valueOffset, (short) Math.min(UNSAFE.getShort(VALUES, valueOffset), value)); | ||||
|             long valueOffset = VALUES.address() + (index * VALUE_SIZE); | ||||
|             UNSAFE.putShort(valueOffset, (short) Math.min(UNSAFE.getShort(valueOffset), value)); | ||||
|             valueOffset += 2; | ||||
|             UNSAFE.putShort(VALUES, valueOffset, (short) Math.max(UNSAFE.getShort(VALUES, valueOffset), value)); | ||||
|             UNSAFE.putShort(valueOffset, (short) Math.max(UNSAFE.getShort(valueOffset), value)); | ||||
|             valueOffset += 2; | ||||
|             UNSAFE.putInt(VALUES, valueOffset, UNSAFE.getInt(VALUES, valueOffset) + 1); | ||||
|             UNSAFE.putInt(valueOffset, UNSAFE.getInt(valueOffset) + 1); | ||||
|             valueOffset += 4; | ||||
|             UNSAFE.putLong(VALUES, valueOffset, UNSAFE.getLong(VALUES, valueOffset) + value); | ||||
|         } | ||||
|  | ||||
|         private boolean equals(byte[] a, int aOffset, byte[] b, int bOffset, int len) { | ||||
|             while (bOffset < len) | ||||
|                 if (a[aOffset++] != b[bOffset++]) | ||||
|                     return false; | ||||
|             return true; | ||||
|             UNSAFE.putLong(valueOffset, UNSAFE.getLong(valueOffset) + value); | ||||
|         } | ||||
|  | ||||
|         public void mergeTo(ResultAggr result) { | ||||
| @@ -120,24 +228,22 @@ public class CalculateAverage_jincongho { | ||||
|             short keyLength; | ||||
|             for (int i = 0; i < MAP_SIZE; i++) { | ||||
|                 // extract key | ||||
|                 keyOffset = UNSAFE.ARRAY_BYTE_BASE_OFFSET + (i * KEY_SIZE); | ||||
|                 if ((keyLength = UNSAFE.getShort(KEYS, keyOffset)) == 0) | ||||
|                 keyOffset = KEYS.address() + (i * KEY_SIZE); | ||||
|                 if ((keyLength = UNSAFE.getShort(keyOffset)) == 0) | ||||
|                     continue; | ||||
|  | ||||
|                 // extract values (if key is not null) | ||||
|                 final long valueOffset = UNSAFE.ARRAY_BYTE_BASE_OFFSET + (i * VALUE_SIZE); | ||||
|                 result.compute(new String(KEYS, (i * KEY_SIZE) + 2, keyLength, StandardCharsets.UTF_8), (k, v) -> { | ||||
|                     short min = UNSAFE.getShort(VALUES, valueOffset); | ||||
|                     short max = UNSAFE.getShort(VALUES, valueOffset + 2); | ||||
|                     int count = UNSAFE.getInt(VALUES, valueOffset + 4); | ||||
|                     long sum = UNSAFE.getLong(VALUES, valueOffset + 8); | ||||
|  | ||||
|                 final long valueOffset = VALUES.address() + (i * VALUE_SIZE); | ||||
|                 result.compute(new ResultAggr.ByteKey(KEYS, (i * KEY_SIZE) + 2, keyLength), (k, v) -> { | ||||
|                     if (v == null) { | ||||
|                         return new ResultAggr.Measurement(min, max, count, sum); | ||||
|                     } | ||||
|                     else { | ||||
|                         return v.update(min, max, count, sum); | ||||
|                         v = new ResultAggr.Measurement(); | ||||
|                     } | ||||
|                     v.min = (short) Math.min(UNSAFE.getShort(valueOffset), v.min); | ||||
|                     v.max = (short) Math.max(UNSAFE.getShort(valueOffset + 2), v.max); | ||||
|                     v.count += UNSAFE.getInt(valueOffset + 4); | ||||
|                     v.sum += UNSAFE.getLong(valueOffset + 8); | ||||
|  | ||||
|                     return v; | ||||
|                 }); | ||||
|             } | ||||
|         } | ||||
| @@ -148,29 +254,55 @@ public class CalculateAverage_jincongho { | ||||
|      * Measurement Aggregation (for all partitions) | ||||
|      * Simple Concurrent Hash Table so all partitions can merge concurrently | ||||
|      */ | ||||
|     protected static class ResultAggr extends ConcurrentHashMap<String, ResultAggr.Measurement> { | ||||
|     protected static class ResultAggr extends ConcurrentHashMap<ResultAggr.ByteKey, ResultAggr.Measurement> { | ||||
|  | ||||
|         /** | ||||
|          * Map key that refers to a range of bytes inside a {@link MemorySegment} instead of | ||||
|          * holding a copy. Equality and hashing work on the raw bytes; a {@link String} is only | ||||
|          * materialized when {@link #toString()} (and therefore {@link #compareTo}) is called. | ||||
|          */ | ||||
|         public static class ByteKey implements Comparable<ByteKey> { | ||||
|             private final MemorySegment data; | ||||
|             private final long offset; | ||||
|             private final short length; | ||||
|             // lazily decoded form of the key bytes, filled in by toString() | ||||
|             private String str; | ||||
|  | ||||
|             /** | ||||
|              * @param data   segment that contains the key bytes | ||||
|              * @param offset byte offset of the key inside {@code data} | ||||
|              * @param length number of key bytes | ||||
|              */ | ||||
|             public ByteKey(MemorySegment data, long offset, short length) { | ||||
|                 this.data = data; | ||||
|                 this.offset = offset; | ||||
|                 this.length = length; | ||||
|             } | ||||
|  | ||||
|             /** | ||||
|              * Two keys are equal when they have the same length and identical bytes. | ||||
|              * | ||||
|              * @param other object to compare with; {@code null} and non-{@code ByteKey} values | ||||
|              *              are simply unequal | ||||
|              */ | ||||
|             @Override | ||||
|             public boolean equals(Object other) { | ||||
|                 if (this == other) | ||||
|                     return true; | ||||
|                 // type/null guard first: a blind cast would throw instead of answering false | ||||
|                 if (!(other instanceof ByteKey that) || length != that.length) | ||||
|                     return false; | ||||
|  | ||||
|                 return !VectorUtils.notEquals(data, offset, that.data, that.offset, length, VectorUtils.BYTE_SPECIES); | ||||
|             } | ||||
|  | ||||
|             /** Hash of the key bytes, computed by {@code VectorUtils.hashCode} on every call. */ | ||||
|             @Override | ||||
|             public int hashCode() { | ||||
|                 return VectorUtils.hashCode(data, offset, length); | ||||
|             } | ||||
|  | ||||
|             /** Decodes the key bytes as UTF-8; the result is cached after the first call. */ | ||||
|             @Override | ||||
|             public String toString() { | ||||
|                 if (str == null) { | ||||
|                     // finally has to do a copy! | ||||
|                     byte[] copy = new byte[length]; | ||||
|                     MemorySegment.copy(data, offset, MemorySegment.ofArray(copy), 0, length); | ||||
|                     str = new String(copy, StandardCharsets.UTF_8); | ||||
|                 } | ||||
|                 return str; | ||||
|             } | ||||
|  | ||||
|             /** Orders keys by their decoded string form. */ | ||||
|             @Override | ||||
|             public int compareTo(ByteKey o) { | ||||
|                 return toString().compareTo(o.toString()); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         protected static class Measurement { | ||||
|             public short min; | ||||
|             public short max; | ||||
|             public int count; | ||||
|             public long sum; | ||||
|  | ||||
|             public Measurement(short min, short max, int count, long sum) { | ||||
|                 this.min = min; | ||||
|                 this.max = max; | ||||
|                 this.count = count; | ||||
|                 this.sum = sum; | ||||
|             } | ||||
|  | ||||
|             public ResultAggr.Measurement update(short min, short max, int count, long sum) { | ||||
|                 this.min = (short) Math.min(min, this.min); | ||||
|                 this.max = (short) Math.max(max, this.max); | ||||
|                 this.count += count; | ||||
|                 this.sum += sum; | ||||
|  | ||||
|                 return this; | ||||
|             } | ||||
|             public short min = Short.MAX_VALUE; | ||||
|             public short max = Short.MIN_VALUE; | ||||
|             public int count = 0; | ||||
|             public long sum = 0; | ||||
|  | ||||
|             @Override | ||||
|             public String toString() { | ||||
| @@ -179,6 +311,10 @@ public class CalculateAverage_jincongho { | ||||
|  | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Creates the result map, forwarding all arguments unchanged to | ||||
|          * {@link ConcurrentHashMap#ConcurrentHashMap(int, float, int)}. | ||||
|          * | ||||
|          * @param initialCapacity  passed through to the superclass constructor | ||||
|          * @param loadFactor       passed through to the superclass constructor | ||||
|          * @param concurrencyLevel passed through to the superclass constructor | ||||
|          */ | ||||
|         public ResultAggr(int initialCapacity, float loadFactor, int concurrencyLevel) { | ||||
|             super(initialCapacity, loadFactor, concurrencyLevel); | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Copies the current entries into a new {@link TreeMap}, which orders them by the | ||||
|          * keys' natural ordering ({@code ByteKey.compareTo}). | ||||
|          * | ||||
|          * @return a sorted copy of this map; later changes to this map are not reflected in it | ||||
|          */ | ||||
|         public Map<ByteKey, Measurement> toSorted() { | ||||
|             return new TreeMap<>(this); | ||||
|         } | ||||
| @@ -194,8 +330,8 @@ public class CalculateAverage_jincongho { | ||||
|  | ||||
|         public Partition(MemorySegment data, long offset, long limit, ResultAggr result) { | ||||
|             this.data = data; | ||||
|             this.offset = data.address() + offset; | ||||
|             this.limit = data.address() + limit; | ||||
|             this.offset = offset; | ||||
|             this.limit = limit; | ||||
|             this.result = result; | ||||
|         } | ||||
|  | ||||
| @@ -203,25 +339,57 @@ public class CalculateAverage_jincongho { | ||||
|         public void run() { | ||||
|             // measurement parsing | ||||
|             PartitionAggr aggr = new PartitionAggr(); | ||||
|             byte[] stationName = new byte[128]; | ||||
|             short stationLength; | ||||
|             int hash; | ||||
|             byte tempBuffer; | ||||
|             while (offset < limit) { | ||||
|  | ||||
|             // main loop (vectorized) | ||||
|             final long loopLimit = limit - (VectorUtils.BYTE_SPECIES.length() * Math.ceilDiv(100, VectorUtils.BYTE_SPECIES.length()) + Long.BYTES); | ||||
|             while (offset < loopLimit) { | ||||
|                 long offsetStart = offset; | ||||
|  | ||||
|                 // find station name upto ";" | ||||
|                 hash = 1; | ||||
|                 stationLength = 0; | ||||
|                 while ((stationName[stationLength] = UNSAFE.getByte(offset++)) != ';') | ||||
|                     hash = hash * 31 + stationName[stationLength++]; | ||||
|                 int found; | ||||
|                 do { | ||||
|                     found = VectorUtils.findDelimiter(data, offset); | ||||
|                     offset += found; | ||||
|                 } while (found == VectorUtils.BYTE_SPECIES.length()); | ||||
|                 short stationLength = (short) (offset - offsetStart); | ||||
|                 int stationHash = VectorUtils.hashCode(data, offsetStart, stationLength); | ||||
|  | ||||
|                 // find measurement upto "\n" (credit: merykitty) | ||||
|                 long numberBits = UNSAFE.getLong(data.address() + ++offset); | ||||
|                 final long invNumberBits = ~numberBits; | ||||
|                 final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & 0x10101000); | ||||
|  | ||||
|                 int shift = 28 - decimalSepPos; | ||||
|                 long signed = (invNumberBits << 59) >> 63; | ||||
|                 long designMask = ~(signed & 0xFF); | ||||
|                 long digits = ((numberBits & designMask) << shift) & 0x0F000F0F00L; | ||||
|                 long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; | ||||
|  | ||||
|                 short fixed = (short) ((absValue ^ signed) - signed); | ||||
|                 offset += (decimalSepPos >>> 3) + 3; | ||||
|  | ||||
|                 // update measurement | ||||
|                 aggr.update(data, offsetStart, stationLength, stationHash, fixed); | ||||
|             } | ||||
|  | ||||
|             // tail loop (simple) | ||||
|             while (offset < limit) { | ||||
|                 long offsetStart = offset; | ||||
|  | ||||
|                 // find station name upto ";" | ||||
|                 short stationLength = 0; | ||||
|                 while (UNSAFE.getByte(data.address() + offset++) != ';') | ||||
|                     stationLength++; | ||||
|                 int stationHash = VectorUtils.hashCode(data, offsetStart, stationLength); | ||||
|  | ||||
|                 // find measurement upto "\n" | ||||
|                 tempBuffer = UNSAFE.getByte(offset++); | ||||
|                 byte tempBuffer = UNSAFE.getByte(data.address() + offset++); | ||||
|                 boolean isNegative = (tempBuffer == '-'); | ||||
|                 short fixed = (short) (isNegative ? 0 : (tempBuffer - '0')); | ||||
|                 while (true) { | ||||
|                     tempBuffer = UNSAFE.getByte(offset++); | ||||
|                     tempBuffer = UNSAFE.getByte(data.address() + offset++); | ||||
|                     if (tempBuffer == '.') { | ||||
|                         fixed = (short) (fixed * 10 + (UNSAFE.getByte(offset) - '0')); | ||||
|                         fixed = (short) (fixed * 10 + (UNSAFE.getByte(data.address() + offset) - '0')); | ||||
|                         offset += 2; | ||||
|                         break; | ||||
|                     } | ||||
| @@ -230,7 +398,7 @@ public class CalculateAverage_jincongho { | ||||
|                 fixed = isNegative ? (short) -fixed : fixed; | ||||
|  | ||||
|                 // update measurement | ||||
|                 aggr.update(stationName, hash, stationLength, fixed); | ||||
|                 aggr.update(data, offsetStart, stationLength, stationHash, fixed); | ||||
|             } | ||||
|  | ||||
|             // measurement result collection | ||||
| @@ -259,13 +427,15 @@ public class CalculateAverage_jincongho { | ||||
|                     partition[i + 1] = data.byteSize(); | ||||
|                     break; | ||||
|                 } | ||||
|  | ||||
|                 // note: vectorize this made performance worse :( | ||||
|                 while (UNSAFE.getByte(data.address() + partition[i + 1]++) != '\n') | ||||
|                     ; | ||||
|             } | ||||
|  | ||||
|             // partition aggregation | ||||
|             var threadList = new Thread[processors]; | ||||
|             ResultAggr result = new ResultAggr(); | ||||
|             ResultAggr result = new ResultAggr(1 << 14, 1, processors); | ||||
|             for (int i = 0; i < processors; i++) { | ||||
|                 threadList[i] = new Thread(new Partition(data, partition[i], partition[i + 1], result)); | ||||
|                 threadList[i].start(); | ||||
| @@ -282,4 +452,58 @@ public class CalculateAverage_jincongho { | ||||
|  | ||||
|     } | ||||
|  | ||||
|     /** Unit Tests **/ | ||||
|  | ||||
|     /** | ||||
|      * Entry point for the built-in self checks: runs the hash test first, then the | ||||
|      * comparison test. A failing check surfaces as a {@link RuntimeException}. | ||||
|      * | ||||
|      * @param args unused | ||||
|      */ | ||||
|     public static void testMain(String[] args) { | ||||
|         testHashCode(); | ||||
|         testNotEquals(); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Checks {@code VectorUtils.hashCode} against {@link Arrays#hashCode(byte[])} for key | ||||
|      * lengths 1 to 100. | ||||
|      * | ||||
|      * <p>The bytes are placed in off-heap memory: the hash routine reads through | ||||
|      * {@code Unsafe} using the segment's address, which is only meaningful for native | ||||
|      * segments, not for {@code MemorySegment.ofArray(...)}. | ||||
|      */ | ||||
|     private static void testHashCode() { | ||||
|         try (Arena arena = Arena.ofConfined()) { | ||||
|             // test key length from 1 to 100 | ||||
|             for (int i = 1; i <= 100; i++) { | ||||
|                 byte[] array = new byte[i]; | ||||
|                 for (int j = 0; j < i; j++) | ||||
|                     array[j] = (byte) j; | ||||
|  | ||||
|                 // mirror the array into a native segment for the code under test | ||||
|                 MemorySegment segment = arena.allocate(i, 1); | ||||
|                 segment.copyFrom(MemorySegment.ofArray(array)); | ||||
|  | ||||
|                 // compare with java default implementation | ||||
|                 assertTrue(VectorUtils.hashCode(segment, 0, (short) i) == Arrays.hashCode(array)); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static void testNotEquals() { | ||||
|         byte[] a = new byte[128]; | ||||
|         byte[] b = new byte[128]; | ||||
|  | ||||
|         // all equals | ||||
|         for (int i = 1; i < 100; i++) { | ||||
|             a[(i + 2) - 1] = 0; | ||||
|             b[i - 1] = 0; | ||||
|             a[(i + 2)] = 10; | ||||
|             b[i] = 10; | ||||
|             assertTrue(!VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_64)); | ||||
|             assertTrue(!VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_128)); | ||||
|             assertTrue(!VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_256)); | ||||
|             assertTrue(!VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_512)); | ||||
|         } | ||||
|  | ||||
|         // one el not equals | ||||
|         for (int i = 1; i < 100; i++) { | ||||
|             a[(i + 2) - 1] = 0; | ||||
|             b[i - 1] = 0; | ||||
|             a[(i + 2)] = 20; | ||||
|             b[i] = 10; | ||||
|             assertTrue(VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_64)); | ||||
|             assertTrue(VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_128)); | ||||
|             assertTrue(VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_256)); | ||||
|             assertTrue(VectorUtils.notEquals(MemorySegment.ofArray(a), 2, MemorySegment.ofArray(b), 0, (short) 100, ByteVector.SPECIES_512)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Minimal assertion helper for the self checks. | ||||
|      * | ||||
|      * @param condition outcome of the check | ||||
|      * @throws RuntimeException with message {@code "Failed test"} when {@code condition} is false | ||||
|      */ | ||||
|     private static void assertTrue(boolean condition) { | ||||
|         if (condition) { | ||||
|             return; | ||||
|         } | ||||
|         throw new RuntimeException("Failed test"); | ||||
|     } | ||||
|  | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user