Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/FbgemmFP16.cc')
-rw-r--r-- src/FbgemmFP16.cc 12
1 files changed, 10 insertions, 2 deletions
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
index 868bc1b..2c0eea3 100644
--- a/src/FbgemmFP16.cc
+++ b/src/FbgemmFP16.cc
@@ -244,11 +244,19 @@ FBGEMM_API void cblas_gemm_compute(
auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows;
for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
assert(kernel_nrows * kb < scratchpad->size());
- PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
+ if (m != 1) {
+ PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
+ gp.A = scratchpad->data();
+ } else {
+ // When m == 1, this is a vector-matrix multiplication, so A does not
+ // need to be transposed by PackA. Instead of packing A into the
+ // scratchpad, point gp.A directly at the original A buffer, offset to
+ // the current k block.
+ gp.A = const_cast<float*>(&A[k_ind]);
+ }
int nbcol = n / Bp.blockColSize();
gp.k = kb;
- gp.A = scratchpad->data();
gp.B = &(Bp(k_ind, 0));
gp.beta = &beta_;
gp.accum = accum;