diff options
Diffstat (limited to 'src/FbgemmFP16.cc')
-rw-r--r-- | src/FbgemmFP16.cc | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc index 868bc1b..2c0eea3 100644 --- a/src/FbgemmFP16.cc +++ b/src/FbgemmFP16.cc @@ -244,11 +244,19 @@ FBGEMM_API void cblas_gemm_compute( auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows; for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) { assert(kernel_nrows * kb < scratchpad->size()); - PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data()); + if (m != 1) { + PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data()); + gp.A = scratchpad->data(); + } else { + // When m == 1, it is actually vector matrix multiplication. We + // don't need to do the transposition for packA here. Instead, we + // can just pass the pointer of the original A matrix buffer to the + // packed A buffer. + gp.A = const_cast<float*>(&A[k_ind]); + } int nbcol = n / Bp.blockColSize(); gp.k = kb; - gp.A = scratchpad->data(); gp.B = &(Bp(k_ind, 0)); gp.beta = &beta_; gp.accum = accum; |