Version 3 (#455)
This commit is contained in:
		| @@ -24,17 +24,12 @@ import java.lang.foreign.ValueLayout; | ||||
| import java.lang.reflect.Field; | ||||
| import java.nio.channels.FileChannel; | ||||
| import java.nio.file.Paths; | ||||
| import java.util.ArrayList; | ||||
| import java.util.HashMap; | ||||
| import java.util.Map; | ||||
| import java.util.TreeMap; | ||||
| import java.util.stream.IntStream; | ||||
|  | ||||
| public class CalculateAverage_roman_r_m { | ||||
|  | ||||
|     public static final int DOT_3_RD_BYTE_MASK = (byte) '.' << 16; | ||||
|     private static final String FILE = "./measurements.txt"; | ||||
|     private static MemorySegment ms; | ||||
|  | ||||
|     private static Unsafe UNSAFE; | ||||
|  | ||||
| @@ -60,7 +55,7 @@ public class CalculateAverage_roman_r_m { | ||||
|         return match != 0 ? firstSetByteIndex(match) : -1; | ||||
|     } | ||||
|  | ||||
|     static long nextNewline(long from) { | ||||
|     static long nextNewline(long from, MemorySegment ms) { | ||||
|         long start = from; | ||||
|         long i; | ||||
|         long next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, start); | ||||
| @@ -71,6 +66,110 @@ public class CalculateAverage_roman_r_m { | ||||
|         return start + i; | ||||
|     } | ||||
|  | ||||
|     static class Worker { | ||||
|         private final MemorySegment ms; | ||||
|         private final long end; | ||||
|         private long offset; | ||||
|  | ||||
|         public Worker(MemorySegment ms, long start, long end) { | ||||
|             this.ms = ms.asSlice(start, end - start); | ||||
|             this.offset = 0; | ||||
|             this.end = end - start; | ||||
|         } | ||||
|  | ||||
|         private void parseName(ByteString station) { | ||||
|             long start = offset; | ||||
|             long pos = -1; | ||||
|  | ||||
|             while (end - offset > 8) { | ||||
|                 long next = UNSAFE.getLong(ms.address() + offset); | ||||
|                 pos = find(next, SEMICOLON_MASK); | ||||
|                 if (pos >= 0) { | ||||
|                     offset += pos; | ||||
|                     break; | ||||
|                 } | ||||
|                 else { | ||||
|                     offset += 8; | ||||
|                 } | ||||
|             } | ||||
|             if (pos < 0) { | ||||
|                 while (UNSAFE.getByte(ms.address() + offset++) != ';') { | ||||
|                 } | ||||
|                 offset--; | ||||
|             } | ||||
|  | ||||
|             int len = (int) (offset - start); | ||||
|             station.offset = start; | ||||
|             station.len = len; | ||||
|             station.hash = 0; | ||||
|  | ||||
|             offset++; | ||||
|         } | ||||
|  | ||||
|         long parseNumberFast() { | ||||
|             long encodedVal = UNSAFE.getLong(ms.address() + offset); | ||||
|  | ||||
|             var len = find(encodedVal, LINE_END_MASK); | ||||
|             offset += len + 1; | ||||
|  | ||||
|             encodedVal ^= broadcast((byte) 0x30); | ||||
|  | ||||
|             long c0 = len == 4 ? 100 : 10; | ||||
|             long c1 = 10 * (len - 3); | ||||
|             long c2 = 4 - len; | ||||
|             long c3 = len - 3; | ||||
|             long a = (encodedVal & 0xFF) * c0; | ||||
|             long b = ((encodedVal & 0xFF00) >>> 8) * c1; | ||||
|             long c = ((encodedVal & 0xFF0000L) >>> 16) * c2; | ||||
|             long d = ((encodedVal & 0xFF000000L) >>> 24) * c3; | ||||
|  | ||||
|             return a + b + c + d; | ||||
|         } | ||||
|  | ||||
|         long parseNumberSlow() { | ||||
|             long val = UNSAFE.getByte(ms.address() + offset++) - '0'; | ||||
|             byte b; | ||||
|             while ((b = UNSAFE.getByte(ms.address() + offset++)) != '.') { | ||||
|                 val = val * 10 + (b - '0'); | ||||
|             } | ||||
|             b = UNSAFE.getByte(ms.address() + offset); | ||||
|             val = val * 10 + (b - '0'); | ||||
|             offset += 2; | ||||
|             return val; | ||||
|         } | ||||
|  | ||||
|         long parseNumber() { | ||||
|             long val; | ||||
|             int neg = 1 - Integer.bitCount(UNSAFE.getByte(ms.address() + offset) & 0x10); | ||||
|             offset += neg; | ||||
|  | ||||
|             if (end - offset > 8) { | ||||
|                 val = parseNumberFast(); | ||||
|             } | ||||
|             else { | ||||
|                 val = parseNumberSlow(); | ||||
|             } | ||||
|             val *= 1 - 2 * neg; | ||||
|             return val; | ||||
|         } | ||||
|  | ||||
|         public TreeMap<String, ResultRow> run() { | ||||
|             var resultStore = new ResultStore(); | ||||
|             var station = new ByteString(ms); | ||||
|  | ||||
|             while (offset < end) { | ||||
|                 parseName(station); | ||||
|                 long val = parseNumber(); | ||||
|                 var a = resultStore.get(station); | ||||
|                 a.min = Math.min(a.min, val); | ||||
|                 a.max = Math.max(a.max, val); | ||||
|                 a.sum += val; | ||||
|                 a.count++; | ||||
|             } | ||||
|             return resultStore.toMap(); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     public static void main(String[] args) throws Exception { | ||||
|         Field f = Unsafe.class.getDeclaredField("theUnsafe"); | ||||
|         f.setAccessible(true); | ||||
| @@ -79,98 +178,18 @@ public class CalculateAverage_roman_r_m { | ||||
|         long fileSize = new File(FILE).length(); | ||||
|  | ||||
|         var channel = FileChannel.open(Paths.get(FILE)); | ||||
|         ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofAuto()); | ||||
|         MemorySegment ms = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.ofAuto()); | ||||
|  | ||||
|         int numThreads = fileSize > Integer.MAX_VALUE ? Runtime.getRuntime().availableProcessors() : 1; | ||||
|         long chunk = fileSize / numThreads; | ||||
|  | ||||
|         var result = IntStream.range(0, numThreads) | ||||
|                 .parallel() | ||||
|                 .mapToObj(i -> { | ||||
|                     boolean lastChunk = i == numThreads - 1; | ||||
|                     long chunkStart = i == 0 ? 0 : nextNewline(i * chunk) + 1; | ||||
|                     long chunkEnd = lastChunk ? fileSize : nextNewline((i + 1) * chunk); | ||||
|  | ||||
|                     var resultStore = new ResultStore(); | ||||
|                     var station = new ByteString(); | ||||
|  | ||||
|                     long offset = chunkStart; | ||||
|                     while (offset < chunkEnd) { | ||||
|                         long start = offset; | ||||
|                         long pos = -1; | ||||
|  | ||||
|                         while (chunkEnd - offset >= 8) { | ||||
|                             long next = UNSAFE.getLong(ms.address() + offset); | ||||
|                             pos = find(next, SEMICOLON_MASK); | ||||
|                             if (pos >= 0) { | ||||
|                                 offset += pos; | ||||
|                                 break; | ||||
|                             } | ||||
|                             else { | ||||
|                                 offset += 8; | ||||
|                             } | ||||
|                         } | ||||
|                         if (pos < 0) { | ||||
|                             while (UNSAFE.getByte(ms.address() + offset++) != ';') { | ||||
|                             } | ||||
|                             offset--; | ||||
|                         } | ||||
|  | ||||
|                         int len = (int) (offset - start); | ||||
|                         // TODO can we not copy and use a reference into the memory segment to perform table lookup? | ||||
|  | ||||
|                         station.offset = start; | ||||
|                         station.len = len; | ||||
|                         station.hash = 0; | ||||
|  | ||||
|                         offset++; | ||||
|  | ||||
|                         long val; | ||||
|                         boolean neg; | ||||
|                         if (!lastChunk || fileSize - offset >= 8) { | ||||
|                             long encodedVal = UNSAFE.getLong(ms.address() + offset); | ||||
|                             neg = (encodedVal & (byte) '-') == (byte) '-'; | ||||
|                             if (neg) { | ||||
|                                 encodedVal >>= 8; | ||||
|                                 offset++; | ||||
|                             } | ||||
|  | ||||
|                             if ((encodedVal & DOT_3_RD_BYTE_MASK) == DOT_3_RD_BYTE_MASK) { | ||||
|                                 val = (encodedVal & 0xFF - 0x30) * 100 + (encodedVal >> 8 & 0xFF - 0x30) * 10 + (encodedVal >> 24 & 0xFF - 0x30); | ||||
|                                 offset += 5; | ||||
|                             } | ||||
|                             else { | ||||
|                                 // based on http://0x80.pl/articles/simd-parsing-int-sequences.html#parsing-and-conversion-of-signed-numbers | ||||
|                                 val = Long.compress(encodedVal, 0xFF00FFL) - 0x303030; | ||||
|                                 val = ((val * 2561) >> 8) & 0xff; | ||||
|                                 offset += 4; | ||||
|                             } | ||||
|                         } | ||||
|                         else { | ||||
|                             neg = UNSAFE.getByte(ms.address() + offset) == '-'; | ||||
|                             if (neg) { | ||||
|                                 offset++; | ||||
|                             } | ||||
|                             val = UNSAFE.getByte(ms.address() + offset++) - '0'; | ||||
|                             byte b; | ||||
|                             while ((b = UNSAFE.getByte(ms.address() + offset++)) != '.') { | ||||
|                                 val = val * 10 + (b - '0'); | ||||
|                             } | ||||
|                             b = UNSAFE.getByte(ms.address() + offset); | ||||
|                             val = val * 10 + (b - '0'); | ||||
|                             offset += 2; | ||||
|                         } | ||||
|  | ||||
|                         if (neg) { | ||||
|                             val = -val; | ||||
|                         } | ||||
|  | ||||
|                         var a = resultStore.get(station); | ||||
|                         a.min = Math.min(a.min, val); | ||||
|                         a.max = Math.max(a.max, val); | ||||
|                         a.sum += val; | ||||
|                         a.count++; | ||||
|                     } | ||||
|                     return resultStore.toMap(); | ||||
|                     long chunkStart = i == 0 ? 0 : nextNewline(i * chunk, ms) + 1; | ||||
|                     long chunkEnd = lastChunk ? fileSize : nextNewline((i + 1) * chunk, ms); | ||||
|                     return new Worker(ms, chunkStart, chunkEnd).run(); | ||||
|                 }).reduce((m1, m2) -> { | ||||
|                     m2.forEach((k, v) -> m1.merge(k, v, ResultRow::merge)); | ||||
|                     return m1; | ||||
| @@ -181,19 +200,24 @@ public class CalculateAverage_roman_r_m { | ||||
|  | ||||
|     static final class ByteString { | ||||
|  | ||||
|         private final MemorySegment ms; | ||||
|         private long offset; | ||||
|         private int len = 0; | ||||
|         private int hash = 0; | ||||
|  | ||||
|         ByteString(MemorySegment ms) { | ||||
|             this.ms = ms; | ||||
|         } | ||||
|  | ||||
|         @Override | ||||
|         public String toString() { | ||||
|             var bytes = new byte[len]; | ||||
|             MemorySegment.copy(ms, ValueLayout.JAVA_BYTE, offset, bytes, 0, len); | ||||
|             UNSAFE.copyMemory(null, ms.address() + offset, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, len); | ||||
|             return new String(bytes, 0, len); | ||||
|         } | ||||
|  | ||||
|         public ByteString copy() { | ||||
|             var copy = new ByteString(); | ||||
|             var copy = new ByteString(ms); | ||||
|             copy.offset = this.offset; | ||||
|             copy.len = this.len; | ||||
|             copy.hash = this.hash; | ||||
| @@ -216,13 +240,18 @@ public class CalculateAverage_roman_r_m { | ||||
|  | ||||
|             long base1 = ms.address() + offset; | ||||
|             long base2 = ms.address() + that.offset; | ||||
|             for (; i + 3 < len; i += 4) { | ||||
|                 int i1 = UNSAFE.getInt(base1 + i); | ||||
|                 int i2 = UNSAFE.getInt(base2 + i); | ||||
|                 if (i1 != i2) { | ||||
|             for (; i + 7 < len; i += 8) { | ||||
|                 long l1 = UNSAFE.getLong(base1 + i); | ||||
|                 long l2 = UNSAFE.getLong(base2 + i); | ||||
|                 if (l1 != l2) { | ||||
|                     return false; | ||||
|                 } | ||||
|             } | ||||
|             if (len >= 8) { | ||||
|                 long l1 = UNSAFE.getLong(base1 + len - 8); | ||||
|                 long l2 = UNSAFE.getLong(base2 + len - 8); | ||||
|                 return l1 == l2; | ||||
|             } | ||||
|             for (; i < len; i++) { | ||||
|                 byte i1 = UNSAFE.getByte(base1 + i); | ||||
|                 byte i2 = UNSAFE.getByte(base2 + i); | ||||
| @@ -236,10 +265,9 @@ public class CalculateAverage_roman_r_m { | ||||
|         @Override | ||||
|         public int hashCode() { | ||||
|             if (hash == 0) { | ||||
|                 // not sure why but it seems to be working a bit better | ||||
|                 hash = UNSAFE.getInt(ms.address() + offset); | ||||
|                 hash = hash >>> (8 * Math.max(0, 4 - len)); | ||||
|                 hash |= len; | ||||
|                 long h = UNSAFE.getLong(ms.address() + offset); | ||||
|                 h = Long.reverseBytes(h) >>> (8 * Math.max(0, 8 - len)); | ||||
|                 hash = (int) (h ^ (h >>> 32)); | ||||
|             } | ||||
|             return hash; | ||||
|         } | ||||
| @@ -269,25 +297,40 @@ public class CalculateAverage_roman_r_m { | ||||
|     } | ||||
|  | ||||
|     static class ResultStore { | ||||
|         private final ArrayList<ResultRow> results = new ArrayList<>(10000); | ||||
|         private final Map<ByteString, Integer> indices = new HashMap<>(10000); | ||||
|         private static final int SIZE = 16384; | ||||
|         private final ByteString[] keys = new ByteString[SIZE]; | ||||
|         private final ResultRow[] values = new ResultRow[SIZE]; | ||||
|  | ||||
|         ResultRow get(ByteString s) { | ||||
|             var idx = indices.get(s); | ||||
|             if (idx != null) { | ||||
|                 return results.get(idx); | ||||
|             int h = s.hashCode(); | ||||
|             int idx = (SIZE - 1) & h; | ||||
|  | ||||
|             int i = 0; | ||||
|             while (keys[idx] != null && !keys[idx].equals(s)) { | ||||
|                 i++; | ||||
|                 idx = (idx + i * i) % SIZE; | ||||
|             } | ||||
|             ResultRow result; | ||||
|             if (keys[idx] == null) { | ||||
|                 keys[idx] = s.copy(); | ||||
|                 result = new ResultRow(); | ||||
|                 values[idx] = result; | ||||
|             } | ||||
|             else { | ||||
|                 ResultRow next = new ResultRow(); | ||||
|                 results.add(next); | ||||
|                 indices.put(s.copy(), results.size() - 1); | ||||
|                 return next; | ||||
|                 result = values[idx]; | ||||
|                 // TODO see it it makes any difference | ||||
|                 // keys[idx].offset = s.offset; | ||||
|             } | ||||
|             return result; | ||||
|         } | ||||
|  | ||||
|         TreeMap<String, ResultRow> toMap() { | ||||
|             var result = new TreeMap<String, ResultRow>(); | ||||
|             indices.forEach((name, idx) -> result.put(name.toString(), results.get(idx))); | ||||
|             for (int i = 0; i < SIZE; i++) { | ||||
|                 if (keys[i] != null) { | ||||
|                     result.put(keys[i].toString(), values[i]); | ||||
|                 } | ||||
|             } | ||||
|             return result; | ||||
|         } | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user