diff options
Diffstat (limited to 'src/FbgemmFP16.cc')
-rw-r--r-- | src/FbgemmFP16.cc | 296 |
1 files changed, 146 insertions, 150 deletions
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc index d3d5c1f..868bc1b 100644 --- a/src/FbgemmFP16.cc +++ b/src/FbgemmFP16.cc @@ -35,23 +35,18 @@ inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) { struct KernelInfo { using knl_ptr = funcptr_fp16; // optimized kernels to cover all cases - static constexpr array<knl_ptr, 15> kernel = { + // 2 in ?x2 should be the same as kernel_ncol_blocks. + // Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to + // the restrictions of ymm register numbers (16). + static constexpr array<knl_ptr, 7> kernel = { { nullptr, - gemmkernel_1x1_AVX2_fA0fB0fC0, - gemmkernel_2x1_AVX2_fA0fB0fC0, - gemmkernel_3x1_AVX2_fA0fB0fC0, - gemmkernel_4x1_AVX2_fA0fB0fC0, - gemmkernel_5x1_AVX2_fA0fB0fC0, - gemmkernel_6x1_AVX2_fA0fB0fC0, - gemmkernel_7x1_AVX2_fA0fB0fC0, - gemmkernel_8x1_AVX2_fA0fB0fC0, - gemmkernel_9x1_AVX2_fA0fB0fC0, - gemmkernel_10x1_AVX2_fA0fB0fC0, - gemmkernel_11x1_AVX2_fA0fB0fC0, - gemmkernel_12x1_AVX2_fA0fB0fC0, - gemmkernel_13x1_AVX2_fA0fB0fC0, - gemmkernel_14x1_AVX2_fA0fB0fC0 + gemmkernel_1x2_AVX2_fA0fB0fC0, + gemmkernel_2x2_AVX2_fA0fB0fC0, + gemmkernel_3x2_AVX2_fA0fB0fC0, + gemmkernel_4x2_AVX2_fA0fB0fC0, + gemmkernel_5x2_AVX2_fA0fB0fC0, + gemmkernel_6x2_AVX2_fA0fB0fC0 } }; @@ -61,131 +56,131 @@ struct KernelInfo { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. { - {{ { 0, 0 }, { 0, 0 } } }, - {{ { 1, 1 }, { 0, 0 } } }, - {{ { 2, 1 }, { 0, 0 } } }, - {{ { 3, 1 }, { 0, 0 } } }, - {{ { 4, 1 }, { 0, 0 } } }, - {{ { 5, 1 }, { 0, 0 } } }, - {{ { 6, 1 }, { 0, 0 } } }, - {{ { 7, 1 }, { 0, 0 } } }, - {{ { 8, 1 }, { 0, 0 } } }, - {{ { 9, 1 }, { 0, 0 } } }, - {{ { 10, 1 }, { 0, 0 } } }, - {{ { 11, 1 }, { 0, 0 } } }, - {{ { 12, 1 }, { 0, 0 } } }, - {{ { 13, 1 }, { 0, 0 } } }, - {{ { 14, 1 }, { 0, 0 } } }, - {{ { 8, 1 }, { 7, 1 } } }, - {{ { 10, 1 }, { 6, 1 } } }, - {{ { 11, 1 }, { 6, 1 } } }, - {{ { 12, 1 }, { 6, 1 } } }, - {{ { 11, 1 }, { 8, 1 } } }, - {{ { 11, 1 }, { 9, 1 } } }, - {{ { 12, 1 }, { 9, 1 } } }, - {{ { 11, 2 }, { 0, 0 } } }, - {{ { 12, 1 }, { 11, 1 } } }, - {{ { 12, 2 }, { 0, 0 } } }, - {{ { 13, 1 }, { 12, 1 } } }, - {{ { 13, 2 }, { 0, 0 } } }, - {{ { 14, 1 }, { 13, 1 } } }, - {{ { 14, 2 }, { 0, 0 } } }, - {{ { 11, 2 }, { 7, 1 } } }, - {{ { 10, 3 }, { 0, 0 } } }, - {{ { 12, 2 }, { 7, 1 } } }, - {{ { 12, 2 }, { 8, 1 } } }, - {{ { 11, 3 }, { 0, 0 } } }, - {{ { 13, 2 }, { 8, 1 } } }, - {{ { 13, 2 }, { 9, 1 } } }, - {{ { 13, 2 }, { 10, 1 } } }, - {{ { 13, 2 }, { 11, 1 } } }, - {{ { 13, 2 }, { 12, 1 } } }, - {{ { 13, 3 }, { 0, 0 } } }, - {{ { 14, 2 }, { 12, 1 } } }, - {{ { 14, 2 }, { 13, 1 } } }, - {{ { 11, 3 }, { 9, 1 } } }, - {{ { 11, 3 }, { 10, 1 } } }, - {{ { 11, 4 }, { 0, 0 } } }, - {{ { 12, 3 }, { 9, 1 } } }, - {{ { 12, 3 }, { 10, 1 } } }, - {{ { 13, 3 }, { 8, 1 } } }, - {{ { 13, 3 }, { 9, 1 } } }, - {{ { 13, 3 }, { 10, 1 } } }, - {{ { 13, 3 }, { 11, 1 } } }, - {{ { 13, 3 }, { 12, 1 } } }, - {{ { 13, 4 }, { 0, 0 } } }, - {{ { 14, 3 }, { 11, 1 } } }, - {{ { 11, 4 }, { 10, 1 } } }, - {{ { 12, 4 }, { 7, 1 } } }, - {{ { 14, 4 }, { 0, 0 } } }, - {{ { 12, 4 }, { 9, 1 } } }, - {{ { 12, 4 }, { 10, 1 } } }, - {{ { 12, 4 }, { 11, 1 } } }, - {{ { 13, 4 }, { 8, 1 } } }, - {{ { 13, 4 }, { 9, 1 } } }, - {{ { 13, 4 }, { 10, 1 } } }, - {{ { 13, 4 }, { 11, 1 } } }, - {{ { 11, 5 }, { 9, 1 } } }, - {{ { 13, 5 }, { 0, 0 } } }, - {{ { 14, 4 }, { 10, 1 } } }, - {{ { 12, 5 }, { 7, 1 } } }, - {{ { 12, 5 }, { 8, 1 } } }, - {{ { 14, 4 }, { 13, 1 } } }, - {{ { 14, 5 }, { 0, 0 } } }, - {{ { 12, 5 }, { 11, 1 } } }, - {{ { 13, 5 }, { 7, 1 } } }, - {{ { 11, 6 }, { 7, 1 } } }, - {{ { 13, 5 }, { 9, 1 } } }, - {{ { 13, 5 }, { 10, 1 } } }, - {{ { 13, 5 }, { 11, 1 } } }, - {{ { 13, 5 }, { 12, 1 } } }, - {{ { 13, 6 }, { 0, 0 } } }, - {{ { 12, 6 }, { 7, 1 } } }, - {{ { 12, 6 }, { 8, 1 } } }, - {{ { 12, 6 }, { 9, 1 } } }, - {{ { 12, 6 }, { 10, 1 } } }, - {{ { 12, 6 }, { 11, 1 } } }, - {{ { 12, 7 }, { 0, 0 } } }, - {{ { 13, 6 }, { 7, 1 } } }, - {{ { 13, 6 }, { 8, 1 } } }, - {{ { 13, 6 }, { 9, 1 } } }, - {{ { 13, 6 }, { 10, 1 } } }, - {{ { 13, 6 }, { 11, 1 } } }, - {{ { 13, 6 }, { 12, 1 } } }, - {{ { 13, 7 }, { 0, 0 } } }, - {{ { 12, 7 }, { 8, 1 } } }, - {{ { 12, 7 }, { 9, 1 } } }, - {{ { 14, 6 }, { 10, 1 } } }, - {{ { 12, 7 }, { 11, 1 } } }, - {{ { 13, 7 }, { 5, 1 } } }, - {{ { 13, 7 }, { 6, 1 } } }, - {{ { 13, 7 }, { 7, 1 } } }, - {{ { 13, 7 }, { 8, 1 } } }, - {{ { 13, 7 }, { 9, 1 } } }, - {{ { 13, 7 }, { 10, 1 } } }, - {{ { 13, 7 }, { 11, 1 } } }, - {{ { 13, 7 }, { 12, 1 } } }, - {{ { 12, 8 }, { 8, 1 } } }, - {{ { 12, 8 }, { 9, 1 } } }, - {{ { 12, 8 }, { 10, 1 } } }, - {{ { 12, 8 }, { 11, 1 } } }, - {{ { 12, 9 }, { 0, 0 } } }, - {{ { 11, 9 }, { 10, 1 } } }, - {{ { 13, 8 }, { 6, 1 } } }, - {{ { 13, 8 }, { 7, 1 } } }, - {{ { 13, 8 }, { 8, 1 } } }, - {{ { 13, 8 }, { 9, 1 } } }, - {{ { 13, 8 }, { 10, 1 } } }, - {{ { 13, 8 }, { 11, 1 } } }, - {{ { 12, 9 }, { 8, 1 } } }, - {{ { 13, 9 }, { 0, 0 } } }, - {{ { 12, 9 }, { 10, 1 } } }, - {{ { 12, 9 }, { 11, 1 } } }, - {{ { 12, 10 }, { 0, 0 } } } + {{ { 0, 0 }, { 0, 0 } } }, // 0 + {{ { 1, 1 }, { 0, 0 } } }, // 1 + {{ { 2, 1 }, { 0, 0 } } }, // 2 + {{ { 3, 1 }, { 0, 0 } } }, // 3 + {{ { 4, 1 }, { 0, 0 } } }, // 4 + {{ { 5, 1 }, { 0, 0 } } }, // 5 + {{ { 6, 1 }, { 0, 0 } } }, // 6 + {{ { 5, 1 }, { 2, 1 } } }, // 7 + {{ { 4, 2 }, { 0, 0 } } }, // 8 + {{ { 5, 1 }, { 4, 1 } } }, // 9 + {{ { 5, 2 }, { 0, 0 } } }, // 10 + {{ { 6, 1 }, { 5, 1 } } }, // 11 + {{ { 6, 2 }, { 0, 0 } } }, // 12 + {{ { 5, 2 }, { 3, 1 } } }, // 13 + {{ { 6, 2 }, { 2, 1 } } }, // 14 + {{ { 5, 3 }, { 0, 0 } } }, // 15 + {{ { 6, 2 }, { 4, 1 } } }, // 16 + {{ { 6, 2 }, { 5, 1 } } }, // 17 + {{ { 6, 3 }, { 0, 0 } } }, // 18 + {{ { 5, 3 }, { 4, 1 } } }, // 19 + {{ { 5, 4 }, { 0, 0 } } }, // 20 + {{ { 5, 3 }, { 6, 1 } } }, // 21 + {{ { 6, 3 }, { 4, 1 } } }, // 22 + {{ { 6, 3 }, { 5, 1 } } }, // 23 + {{ { 6, 4 }, { 0, 0 } } }, // 24 + {{ { 5, 5 }, { 0, 0 } } }, // 25 + {{ { 5, 4 }, { 6, 1 } } }, // 26 + {{ { 6, 4 }, { 3, 1 } } }, // 27 + {{ { 6, 4 }, { 4, 1 } } }, // 28 + {{ { 6, 4 }, { 5, 1 } } }, // 29 + {{ { 6, 5 }, { 0, 0 } } }, // 30 + {{ { 6, 5 }, { 1, 1 } } }, // 31 + {{ { 6, 5 }, { 2, 1 } } }, // 32 + {{ { 6, 5 }, { 3, 1 } } }, // 33 + {{ { 6, 5 }, { 4, 1 } } }, // 34 + {{ { 6, 5 }, { 5, 1 } } }, // 35 + {{ { 6, 6 }, { 0, 0 } } }, // 36 + {{ { 6, 6 }, { 1, 1 } } }, // 37 + {{ { 6, 6 }, { 2, 1 } } }, // 38 + {{ { 6, 6 }, { 3, 1 } } }, // 39 + {{ { 6, 6 }, { 4, 1 } } }, // 40 + {{ { 6, 6 }, { 5, 1 } } }, // 41 + {{ { 6, 7 }, { 0, 0 } } }, // 42 + {{ { 6, 7 }, { 1, 1 } } }, // 43 + {{ { 6, 7 }, { 2, 1 } } }, // 44 + {{ { 6, 7 }, { 3, 1 } } }, // 45 + {{ { 6, 7 }, { 4, 1 } } }, // 46 + {{ { 6, 7 }, { 5, 1 } } }, // 47 + {{ { 6, 8 }, { 0, 0 } } }, // 48 + {{ { 6, 8 }, { 1, 1 } } }, // 49 + {{ { 6, 8 }, { 2, 1 } } }, // 50 + {{ { 6, 8 }, { 3, 1 } } }, // 51 + {{ { 6, 8 }, { 4, 1 } } }, // 52 + {{ { 6, 8 }, { 5, 1 } } }, // 53 + {{ { 6, 9 }, { 0, 0 } } }, // 54 + {{ { 6, 9 }, { 1, 1 } } }, // 55 + {{ { 6, 9 }, { 2, 1 } } }, // 56 + {{ { 6, 9 }, { 3, 1 } } }, // 57 + {{ { 6, 9 }, { 4, 1 } } }, // 58 + {{ { 6, 9 }, { 5, 1 } } }, // 59 + {{ { 6, 10 }, { 0, 0 } } }, // 60 + {{ { 6, 10 }, { 1, 1 } } }, // 61 + {{ { 6, 10 }, { 2, 1 } } }, // 62 + {{ { 6, 10 }, { 3, 1 } } }, // 63 + {{ { 6, 10 }, { 4, 1 } } }, // 64 + {{ { 6, 10 }, { 5, 1 } } }, // 65 + {{ { 6, 11 }, { 0, 0 } } }, // 66 + {{ { 6, 11 }, { 1, 1 } } }, // 67 + {{ { 6, 11 }, { 2, 1 } } }, // 68 + {{ { 6, 11 }, { 3, 1 } } }, // 69 + {{ { 6, 11 }, { 4, 1 } } }, // 70 + {{ { 6, 11 }, { 5, 1 } } }, // 71 + {{ { 6, 12 }, { 0, 0 } } }, // 72 + {{ { 6, 12 }, { 1, 1 } } }, // 73 + {{ { 6, 12 }, { 2, 1 } } }, // 74 + {{ { 6, 12 }, { 3, 1 } } }, // 75 + {{ { 6, 12 }, { 4, 1 } } }, // 76 + {{ { 6, 12 }, { 5, 1 } } }, // 77 + {{ { 6, 13 }, { 0, 0 } } }, // 78 + {{ { 6, 13 }, { 1, 1 } } }, // 79 + {{ { 6, 13 }, { 2, 1 } } }, // 80 + {{ { 6, 13 }, { 3, 1 } } }, // 81 + {{ { 6, 13 }, { 4, 1 } } }, // 82 + {{ { 6, 13 }, { 5, 1 } } }, // 83 + {{ { 6, 14 }, { 0, 0 } } }, // 84 + {{ { 6, 14 }, { 1, 1 } } }, // 85 + {{ { 6, 14 }, { 2, 1 } } }, // 86 + {{ { 6, 14 }, { 3, 1 } } }, // 87 + {{ { 6, 14 }, { 4, 1 } } }, // 88 + {{ { 6, 14 }, { 5, 1 } } }, // 89 + {{ { 6, 15 }, { 0, 0 } } }, // 90 + {{ { 6, 15 }, { 1, 1 } } }, // 91 + {{ { 6, 15 }, { 2, 1 } } }, // 92 + {{ { 6, 15 }, { 3, 1 } } }, // 93 + {{ { 6, 15 }, { 4, 1 } } }, // 94 + {{ { 6, 15 }, { 5, 1 } } }, // 95 + {{ { 6, 16 }, { 0, 0 } } }, // 96 + {{ { 6, 16 }, { 1, 1 } } }, // 97 + {{ { 6, 16 }, { 2, 1 } } }, // 98 + {{ { 6, 16 }, { 3, 1 } } }, // 99 + {{ { 6, 16 }, { 4, 1 } } }, // 100 + {{ { 6, 16 }, { 5, 1 } } }, // 101 + {{ { 6, 17 }, { 0, 0 } } }, // 102 + {{ { 6, 17 }, { 1, 1 } } }, // 103 + {{ { 6, 17 }, { 2, 1 } } }, // 104 + {{ { 6, 17 }, { 3, 1 } } }, // 105 + {{ { 6, 17 }, { 4, 1 } } }, // 106 + {{ { 6, 17 }, { 5, 1 } } }, // 107 + {{ { 6, 18 }, { 0, 0 } } }, // 108 + {{ { 6, 18 }, { 1, 1 } } }, // 109 + {{ { 6, 18 }, { 2, 1 } } }, // 110 + {{ { 6, 18 }, { 3, 1 } } }, // 111 + {{ { 6, 18 }, { 4, 1 } } }, // 112 + {{ { 6, 18 }, { 5, 1 } } }, // 113 + {{ { 6, 19 }, { 0, 0 } } }, // 114 + {{ { 6, 19 }, { 1, 1 } } }, // 115 + {{ { 6, 19 }, { 2, 1 } } }, // 116 + {{ { 6, 19 }, { 3, 1 } } }, // 117 + {{ { 6, 19 }, { 4, 1 } } }, // 118 + {{ { 6, 19 }, { 5, 1 } } }, // 119 + {{ { 6, 20 }, { 0, 0 } } }, // 120 } }; }; -constexpr array<KernelInfo::knl_ptr, 15> KernelInfo::kernel; +constexpr array<KernelInfo::knl_ptr, 7> KernelInfo::kernel; constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition; // autotuned kernel splits for various cases m = 1:mb_max @@ -208,8 +203,8 @@ FBGEMM_API void cblas_gemm_compute( const int n = Bp.numCols(), k = Bp.numRows(), ldc = n; const int mb_max = 120; constexpr int simd_width = 8; - constexpr int kernel_ncol_blocks = 1; - constexpr int kernel_ncols = kernel_ncol_blocks * simd_width; + int kernel_ncol_blocks = Bp.kernelNumColBlocks(); + int kernel_ncols = kernel_ncol_blocks * simd_width; // private scratchpad storage static thread_local unique_ptr<std::array<float, 256 * 1024>> scratchpad( @@ -267,7 +262,7 @@ FBGEMM_API void cblas_gemm_compute( fbgemmGetRange( num_threads, thread_id, gp.b_block_cols, 1, jb_begin, jb_end); gp.B += gp.k * Bp.blockColSize() * jb_begin; - gp.C += 8 * jb_begin; + gp.C += Bp.blockColSize() * jb_begin; gp.b_block_cols = jb_end - jb_begin; if (gp.b_block_cols) { KernelInfo::kernel[kernel_nrows](&gp); @@ -279,7 +274,7 @@ FBGEMM_API void cblas_gemm_compute( fbgemmGetRange( num_threads, thread_id, gp.b_block_cols, 1, jb_begin, jb_end); gp.B += gp.k * Bp.blockColSize() * jb_begin; - gp.C += 8 * jb_begin; + gp.C += Bp.blockColSize() * jb_begin; gp.b_block_cols = jb_end - jb_begin; if (gp.b_block_cols) { KernelInfo::kernel[kernel_nrows](&gp); @@ -291,35 +286,36 @@ FBGEMM_API void cblas_gemm_compute( // leftover int rem = n - last_blk_col; assert(rem < kernel_ncols); - int b = (rem % simd_width) ? ((rem + simd_width) / simd_width) - : (rem / simd_width); - assert(b == 1); - if ((rem % simd_width) == 0) { + + if ((rem % Bp.blockColSize()) == 0) { gp.B = &(Bp(k_ind, last_blk_col)); gp.C = &C[m2 * ldc + last_blk_col]; gp.b_block_cols = 1; KernelInfo::kernel[kernel_nrows](&gp); } else { - // small temporary buffer + // small temporary buffer: the size should be larger than the + // required kernel_nrow x kernel_ncols elements computed in the + // registers. float c_tmp[16 * 24] = {0}; assert((16 * 24) > kernel_nrows * kernel_ncols); gp.B = &(Bp(k_ind, last_blk_col)); gp.C = c_tmp; - gp.ldc = 8 * sizeof(C[0]); + gp.ldc = kernel_ncols * sizeof(C[0]); gp.b_block_cols = 1; KernelInfo::kernel[kernel_nrows](&gp); for (int i = 0; i < kernel_nrows; i++) { // Todo: use assembly for (int j = last_blk_col; j < n; j++) { assert( - i * 8 + (j - last_blk_col) < + i * kernel_ncols + (j - last_blk_col) < sizeof(c_tmp) / sizeof(c_tmp[0])); if (accum == 0) { - C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)]; + C[(m2 + i) * ldc + j] = + c_tmp[i * kernel_ncols + (j - last_blk_col)]; } else { C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] + - c_tmp[i * 8 + (j - last_blk_col)]; + c_tmp[i * kernel_ncols + (j - last_blk_col)]; } } } |