Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/FbgemmFP16.cc')
-rw-r--r--src/FbgemmFP16.cc296
1 files changed, 146 insertions, 150 deletions
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
index d3d5c1f..868bc1b 100644
--- a/src/FbgemmFP16.cc
+++ b/src/FbgemmFP16.cc
@@ -35,23 +35,18 @@ inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) {
struct KernelInfo {
using knl_ptr = funcptr_fp16;
// optimized kernels to cover all cases
- static constexpr array<knl_ptr, 15> kernel = {
+ // 2 in ?x2 should be the same as kernel_ncol_blocks.
+ // Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to
+ // the restrictions of ymm register numbers (16).
+ static constexpr array<knl_ptr, 7> kernel = {
{
nullptr,
- gemmkernel_1x1_AVX2_fA0fB0fC0,
- gemmkernel_2x1_AVX2_fA0fB0fC0,
- gemmkernel_3x1_AVX2_fA0fB0fC0,
- gemmkernel_4x1_AVX2_fA0fB0fC0,
- gemmkernel_5x1_AVX2_fA0fB0fC0,
- gemmkernel_6x1_AVX2_fA0fB0fC0,
- gemmkernel_7x1_AVX2_fA0fB0fC0,
- gemmkernel_8x1_AVX2_fA0fB0fC0,
- gemmkernel_9x1_AVX2_fA0fB0fC0,
- gemmkernel_10x1_AVX2_fA0fB0fC0,
- gemmkernel_11x1_AVX2_fA0fB0fC0,
- gemmkernel_12x1_AVX2_fA0fB0fC0,
- gemmkernel_13x1_AVX2_fA0fB0fC0,
- gemmkernel_14x1_AVX2_fA0fB0fC0
+ gemmkernel_1x2_AVX2_fA0fB0fC0,
+ gemmkernel_2x2_AVX2_fA0fB0fC0,
+ gemmkernel_3x2_AVX2_fA0fB0fC0,
+ gemmkernel_4x2_AVX2_fA0fB0fC0,
+ gemmkernel_5x2_AVX2_fA0fB0fC0,
+ gemmkernel_6x2_AVX2_fA0fB0fC0
}
};
@@ -61,131 +56,131 @@ struct KernelInfo {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{
- {{ { 0, 0 }, { 0, 0 } } },
- {{ { 1, 1 }, { 0, 0 } } },
- {{ { 2, 1 }, { 0, 0 } } },
- {{ { 3, 1 }, { 0, 0 } } },
- {{ { 4, 1 }, { 0, 0 } } },
- {{ { 5, 1 }, { 0, 0 } } },
- {{ { 6, 1 }, { 0, 0 } } },
- {{ { 7, 1 }, { 0, 0 } } },
- {{ { 8, 1 }, { 0, 0 } } },
- {{ { 9, 1 }, { 0, 0 } } },
- {{ { 10, 1 }, { 0, 0 } } },
- {{ { 11, 1 }, { 0, 0 } } },
- {{ { 12, 1 }, { 0, 0 } } },
- {{ { 13, 1 }, { 0, 0 } } },
- {{ { 14, 1 }, { 0, 0 } } },
- {{ { 8, 1 }, { 7, 1 } } },
- {{ { 10, 1 }, { 6, 1 } } },
- {{ { 11, 1 }, { 6, 1 } } },
- {{ { 12, 1 }, { 6, 1 } } },
- {{ { 11, 1 }, { 8, 1 } } },
- {{ { 11, 1 }, { 9, 1 } } },
- {{ { 12, 1 }, { 9, 1 } } },
- {{ { 11, 2 }, { 0, 0 } } },
- {{ { 12, 1 }, { 11, 1 } } },
- {{ { 12, 2 }, { 0, 0 } } },
- {{ { 13, 1 }, { 12, 1 } } },
- {{ { 13, 2 }, { 0, 0 } } },
- {{ { 14, 1 }, { 13, 1 } } },
- {{ { 14, 2 }, { 0, 0 } } },
- {{ { 11, 2 }, { 7, 1 } } },
- {{ { 10, 3 }, { 0, 0 } } },
- {{ { 12, 2 }, { 7, 1 } } },
- {{ { 12, 2 }, { 8, 1 } } },
- {{ { 11, 3 }, { 0, 0 } } },
- {{ { 13, 2 }, { 8, 1 } } },
- {{ { 13, 2 }, { 9, 1 } } },
- {{ { 13, 2 }, { 10, 1 } } },
- {{ { 13, 2 }, { 11, 1 } } },
- {{ { 13, 2 }, { 12, 1 } } },
- {{ { 13, 3 }, { 0, 0 } } },
- {{ { 14, 2 }, { 12, 1 } } },
- {{ { 14, 2 }, { 13, 1 } } },
- {{ { 11, 3 }, { 9, 1 } } },
- {{ { 11, 3 }, { 10, 1 } } },
- {{ { 11, 4 }, { 0, 0 } } },
- {{ { 12, 3 }, { 9, 1 } } },
- {{ { 12, 3 }, { 10, 1 } } },
- {{ { 13, 3 }, { 8, 1 } } },
- {{ { 13, 3 }, { 9, 1 } } },
- {{ { 13, 3 }, { 10, 1 } } },
- {{ { 13, 3 }, { 11, 1 } } },
- {{ { 13, 3 }, { 12, 1 } } },
- {{ { 13, 4 }, { 0, 0 } } },
- {{ { 14, 3 }, { 11, 1 } } },
- {{ { 11, 4 }, { 10, 1 } } },
- {{ { 12, 4 }, { 7, 1 } } },
- {{ { 14, 4 }, { 0, 0 } } },
- {{ { 12, 4 }, { 9, 1 } } },
- {{ { 12, 4 }, { 10, 1 } } },
- {{ { 12, 4 }, { 11, 1 } } },
- {{ { 13, 4 }, { 8, 1 } } },
- {{ { 13, 4 }, { 9, 1 } } },
- {{ { 13, 4 }, { 10, 1 } } },
- {{ { 13, 4 }, { 11, 1 } } },
- {{ { 11, 5 }, { 9, 1 } } },
- {{ { 13, 5 }, { 0, 0 } } },
- {{ { 14, 4 }, { 10, 1 } } },
- {{ { 12, 5 }, { 7, 1 } } },
- {{ { 12, 5 }, { 8, 1 } } },
- {{ { 14, 4 }, { 13, 1 } } },
- {{ { 14, 5 }, { 0, 0 } } },
- {{ { 12, 5 }, { 11, 1 } } },
- {{ { 13, 5 }, { 7, 1 } } },
- {{ { 11, 6 }, { 7, 1 } } },
- {{ { 13, 5 }, { 9, 1 } } },
- {{ { 13, 5 }, { 10, 1 } } },
- {{ { 13, 5 }, { 11, 1 } } },
- {{ { 13, 5 }, { 12, 1 } } },
- {{ { 13, 6 }, { 0, 0 } } },
- {{ { 12, 6 }, { 7, 1 } } },
- {{ { 12, 6 }, { 8, 1 } } },
- {{ { 12, 6 }, { 9, 1 } } },
- {{ { 12, 6 }, { 10, 1 } } },
- {{ { 12, 6 }, { 11, 1 } } },
- {{ { 12, 7 }, { 0, 0 } } },
- {{ { 13, 6 }, { 7, 1 } } },
- {{ { 13, 6 }, { 8, 1 } } },
- {{ { 13, 6 }, { 9, 1 } } },
- {{ { 13, 6 }, { 10, 1 } } },
- {{ { 13, 6 }, { 11, 1 } } },
- {{ { 13, 6 }, { 12, 1 } } },
- {{ { 13, 7 }, { 0, 0 } } },
- {{ { 12, 7 }, { 8, 1 } } },
- {{ { 12, 7 }, { 9, 1 } } },
- {{ { 14, 6 }, { 10, 1 } } },
- {{ { 12, 7 }, { 11, 1 } } },
- {{ { 13, 7 }, { 5, 1 } } },
- {{ { 13, 7 }, { 6, 1 } } },
- {{ { 13, 7 }, { 7, 1 } } },
- {{ { 13, 7 }, { 8, 1 } } },
- {{ { 13, 7 }, { 9, 1 } } },
- {{ { 13, 7 }, { 10, 1 } } },
- {{ { 13, 7 }, { 11, 1 } } },
- {{ { 13, 7 }, { 12, 1 } } },
- {{ { 12, 8 }, { 8, 1 } } },
- {{ { 12, 8 }, { 9, 1 } } },
- {{ { 12, 8 }, { 10, 1 } } },
- {{ { 12, 8 }, { 11, 1 } } },
- {{ { 12, 9 }, { 0, 0 } } },
- {{ { 11, 9 }, { 10, 1 } } },
- {{ { 13, 8 }, { 6, 1 } } },
- {{ { 13, 8 }, { 7, 1 } } },
- {{ { 13, 8 }, { 8, 1 } } },
- {{ { 13, 8 }, { 9, 1 } } },
- {{ { 13, 8 }, { 10, 1 } } },
- {{ { 13, 8 }, { 11, 1 } } },
- {{ { 12, 9 }, { 8, 1 } } },
- {{ { 13, 9 }, { 0, 0 } } },
- {{ { 12, 9 }, { 10, 1 } } },
- {{ { 12, 9 }, { 11, 1 } } },
- {{ { 12, 10 }, { 0, 0 } } }
+ {{ { 0, 0 }, { 0, 0 } } }, // 0
+ {{ { 1, 1 }, { 0, 0 } } }, // 1
+ {{ { 2, 1 }, { 0, 0 } } }, // 2
+ {{ { 3, 1 }, { 0, 0 } } }, // 3
+ {{ { 4, 1 }, { 0, 0 } } }, // 4
+ {{ { 5, 1 }, { 0, 0 } } }, // 5
+ {{ { 6, 1 }, { 0, 0 } } }, // 6
+ {{ { 5, 1 }, { 2, 1 } } }, // 7
+ {{ { 4, 2 }, { 0, 0 } } }, // 8
+ {{ { 5, 1 }, { 4, 1 } } }, // 9
+ {{ { 5, 2 }, { 0, 0 } } }, // 10
+ {{ { 6, 1 }, { 5, 1 } } }, // 11
+ {{ { 6, 2 }, { 0, 0 } } }, // 12
+ {{ { 5, 2 }, { 3, 1 } } }, // 13
+ {{ { 6, 2 }, { 2, 1 } } }, // 14
+ {{ { 5, 3 }, { 0, 0 } } }, // 15
+ {{ { 6, 2 }, { 4, 1 } } }, // 16
+ {{ { 6, 2 }, { 5, 1 } } }, // 17
+ {{ { 6, 3 }, { 0, 0 } } }, // 18
+ {{ { 5, 3 }, { 4, 1 } } }, // 19
+ {{ { 5, 4 }, { 0, 0 } } }, // 20
+ {{ { 5, 3 }, { 6, 1 } } }, // 21
+ {{ { 6, 3 }, { 4, 1 } } }, // 22
+ {{ { 6, 3 }, { 5, 1 } } }, // 23
+ {{ { 6, 4 }, { 0, 0 } } }, // 24
+ {{ { 5, 5 }, { 0, 0 } } }, // 25
+ {{ { 5, 4 }, { 6, 1 } } }, // 26
+ {{ { 6, 4 }, { 3, 1 } } }, // 27
+ {{ { 6, 4 }, { 4, 1 } } }, // 28
+ {{ { 6, 4 }, { 5, 1 } } }, // 29
+ {{ { 6, 5 }, { 0, 0 } } }, // 30
+ {{ { 6, 5 }, { 1, 1 } } }, // 31
+ {{ { 6, 5 }, { 2, 1 } } }, // 32
+ {{ { 6, 5 }, { 3, 1 } } }, // 33
+ {{ { 6, 5 }, { 4, 1 } } }, // 34
+ {{ { 6, 5 }, { 5, 1 } } }, // 35
+ {{ { 6, 6 }, { 0, 0 } } }, // 36
+ {{ { 6, 6 }, { 1, 1 } } }, // 37
+ {{ { 6, 6 }, { 2, 1 } } }, // 38
+ {{ { 6, 6 }, { 3, 1 } } }, // 39
+ {{ { 6, 6 }, { 4, 1 } } }, // 40
+ {{ { 6, 6 }, { 5, 1 } } }, // 41
+ {{ { 6, 7 }, { 0, 0 } } }, // 42
+ {{ { 6, 7 }, { 1, 1 } } }, // 43
+ {{ { 6, 7 }, { 2, 1 } } }, // 44
+ {{ { 6, 7 }, { 3, 1 } } }, // 45
+ {{ { 6, 7 }, { 4, 1 } } }, // 46
+ {{ { 6, 7 }, { 5, 1 } } }, // 47
+ {{ { 6, 8 }, { 0, 0 } } }, // 48
+ {{ { 6, 8 }, { 1, 1 } } }, // 49
+ {{ { 6, 8 }, { 2, 1 } } }, // 50
+ {{ { 6, 8 }, { 3, 1 } } }, // 51
+ {{ { 6, 8 }, { 4, 1 } } }, // 52
+ {{ { 6, 8 }, { 5, 1 } } }, // 53
+ {{ { 6, 9 }, { 0, 0 } } }, // 54
+ {{ { 6, 9 }, { 1, 1 } } }, // 55
+ {{ { 6, 9 }, { 2, 1 } } }, // 56
+ {{ { 6, 9 }, { 3, 1 } } }, // 57
+ {{ { 6, 9 }, { 4, 1 } } }, // 58
+ {{ { 6, 9 }, { 5, 1 } } }, // 59
+ {{ { 6, 10 }, { 0, 0 } } }, // 60
+ {{ { 6, 10 }, { 1, 1 } } }, // 61
+ {{ { 6, 10 }, { 2, 1 } } }, // 62
+ {{ { 6, 10 }, { 3, 1 } } }, // 63
+ {{ { 6, 10 }, { 4, 1 } } }, // 64
+ {{ { 6, 10 }, { 5, 1 } } }, // 65
+ {{ { 6, 11 }, { 0, 0 } } }, // 66
+ {{ { 6, 11 }, { 1, 1 } } }, // 67
+ {{ { 6, 11 }, { 2, 1 } } }, // 68
+ {{ { 6, 11 }, { 3, 1 } } }, // 69
+ {{ { 6, 11 }, { 4, 1 } } }, // 70
+ {{ { 6, 11 }, { 5, 1 } } }, // 71
+ {{ { 6, 12 }, { 0, 0 } } }, // 72
+ {{ { 6, 12 }, { 1, 1 } } }, // 73
+ {{ { 6, 12 }, { 2, 1 } } }, // 74
+ {{ { 6, 12 }, { 3, 1 } } }, // 75
+ {{ { 6, 12 }, { 4, 1 } } }, // 76
+ {{ { 6, 12 }, { 5, 1 } } }, // 77
+ {{ { 6, 13 }, { 0, 0 } } }, // 78
+ {{ { 6, 13 }, { 1, 1 } } }, // 79
+ {{ { 6, 13 }, { 2, 1 } } }, // 80
+ {{ { 6, 13 }, { 3, 1 } } }, // 81
+ {{ { 6, 13 }, { 4, 1 } } }, // 82
+ {{ { 6, 13 }, { 5, 1 } } }, // 83
+ {{ { 6, 14 }, { 0, 0 } } }, // 84
+ {{ { 6, 14 }, { 1, 1 } } }, // 85
+ {{ { 6, 14 }, { 2, 1 } } }, // 86
+ {{ { 6, 14 }, { 3, 1 } } }, // 87
+ {{ { 6, 14 }, { 4, 1 } } }, // 88
+ {{ { 6, 14 }, { 5, 1 } } }, // 89
+ {{ { 6, 15 }, { 0, 0 } } }, // 90
+ {{ { 6, 15 }, { 1, 1 } } }, // 91
+ {{ { 6, 15 }, { 2, 1 } } }, // 92
+ {{ { 6, 15 }, { 3, 1 } } }, // 93
+ {{ { 6, 15 }, { 4, 1 } } }, // 94
+ {{ { 6, 15 }, { 5, 1 } } }, // 95
+ {{ { 6, 16 }, { 0, 0 } } }, // 96
+ {{ { 6, 16 }, { 1, 1 } } }, // 97
+ {{ { 6, 16 }, { 2, 1 } } }, // 98
+ {{ { 6, 16 }, { 3, 1 } } }, // 99
+ {{ { 6, 16 }, { 4, 1 } } }, // 100
+ {{ { 6, 16 }, { 5, 1 } } }, // 101
+ {{ { 6, 17 }, { 0, 0 } } }, // 102
+ {{ { 6, 17 }, { 1, 1 } } }, // 103
+ {{ { 6, 17 }, { 2, 1 } } }, // 104
+ {{ { 6, 17 }, { 3, 1 } } }, // 105
+ {{ { 6, 17 }, { 4, 1 } } }, // 106
+ {{ { 6, 17 }, { 5, 1 } } }, // 107
+ {{ { 6, 18 }, { 0, 0 } } }, // 108
+ {{ { 6, 18 }, { 1, 1 } } }, // 109
+ {{ { 6, 18 }, { 2, 1 } } }, // 110
+ {{ { 6, 18 }, { 3, 1 } } }, // 111
+ {{ { 6, 18 }, { 4, 1 } } }, // 112
+ {{ { 6, 18 }, { 5, 1 } } }, // 113
+ {{ { 6, 19 }, { 0, 0 } } }, // 114
+ {{ { 6, 19 }, { 1, 1 } } }, // 115
+ {{ { 6, 19 }, { 2, 1 } } }, // 116
+ {{ { 6, 19 }, { 3, 1 } } }, // 117
+ {{ { 6, 19 }, { 4, 1 } } }, // 118
+ {{ { 6, 19 }, { 5, 1 } } }, // 119
+ {{ { 6, 20 }, { 0, 0 } } }, // 120
}
};
};
-constexpr array<KernelInfo::knl_ptr, 15> KernelInfo::kernel;
+constexpr array<KernelInfo::knl_ptr, 7> KernelInfo::kernel;
constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition;
// autotuned kernel splits for various cases m = 1:mb_max
@@ -208,8 +203,8 @@ FBGEMM_API void cblas_gemm_compute(
const int n = Bp.numCols(), k = Bp.numRows(), ldc = n;
const int mb_max = 120;
constexpr int simd_width = 8;
- constexpr int kernel_ncol_blocks = 1;
- constexpr int kernel_ncols = kernel_ncol_blocks * simd_width;
+ int kernel_ncol_blocks = Bp.kernelNumColBlocks();
+ int kernel_ncols = kernel_ncol_blocks * simd_width;
// private scratchpad storage
static thread_local unique_ptr<std::array<float, 256 * 1024>> scratchpad(
@@ -267,7 +262,7 @@ FBGEMM_API void cblas_gemm_compute(
fbgemmGetRange(
num_threads, thread_id, gp.b_block_cols, 1, jb_begin, jb_end);
gp.B += gp.k * Bp.blockColSize() * jb_begin;
- gp.C += 8 * jb_begin;
+ gp.C += Bp.blockColSize() * jb_begin;
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
KernelInfo::kernel[kernel_nrows](&gp);
@@ -279,7 +274,7 @@ FBGEMM_API void cblas_gemm_compute(
fbgemmGetRange(
num_threads, thread_id, gp.b_block_cols, 1, jb_begin, jb_end);
gp.B += gp.k * Bp.blockColSize() * jb_begin;
- gp.C += 8 * jb_begin;
+ gp.C += Bp.blockColSize() * jb_begin;
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
KernelInfo::kernel[kernel_nrows](&gp);
@@ -291,35 +286,36 @@ FBGEMM_API void cblas_gemm_compute(
// leftover
int rem = n - last_blk_col;
assert(rem < kernel_ncols);
- int b = (rem % simd_width) ? ((rem + simd_width) / simd_width)
- : (rem / simd_width);
- assert(b == 1);
- if ((rem % simd_width) == 0) {
+
+ if ((rem % Bp.blockColSize()) == 0) {
gp.B = &(Bp(k_ind, last_blk_col));
gp.C = &C[m2 * ldc + last_blk_col];
gp.b_block_cols = 1;
KernelInfo::kernel[kernel_nrows](&gp);
} else {
- // small temporary buffer
+ // small temporary buffer: the size should be larger than the
+ // required kernel_nrow x kernel_ncols elements computed in the
+ // registers.
float c_tmp[16 * 24] = {0};
assert((16 * 24) > kernel_nrows * kernel_ncols);
gp.B = &(Bp(k_ind, last_blk_col));
gp.C = c_tmp;
- gp.ldc = 8 * sizeof(C[0]);
+ gp.ldc = kernel_ncols * sizeof(C[0]);
gp.b_block_cols = 1;
KernelInfo::kernel[kernel_nrows](&gp);
for (int i = 0; i < kernel_nrows; i++) {
// Todo: use assembly
for (int j = last_blk_col; j < n; j++) {
assert(
- i * 8 + (j - last_blk_col) <
+ i * kernel_ncols + (j - last_blk_col) <
sizeof(c_tmp) / sizeof(c_tmp[0]));
if (accum == 0) {
- C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)];
+ C[(m2 + i) * ldc + j] =
+ c_tmp[i * kernel_ncols + (j - last_blk_col)];
} else {
C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] +
- c_tmp[i * 8 + (j - last_blk_col)];
+ c_tmp[i * kernel_ncols + (j - last_blk_col)];
}
}
}