github.com/marian-nmt/FBGEMM.git
author Jianyu Huang <jianyuhuang@fb.com> 2019-03-08 21:34:32 +0300
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com> 2019-03-08 21:37:10 +0300
commit 50b43162fd1742122d01f2704945c78f13e0d73e
tree d5fee7d82429cd63aa1f8bee3628e822bd010436 /src/FbgemmFP16.cc
parent 844dacc267391cd2a725d81c2495636f0765771b

No need for PackA when m==1 (#83)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/83

When m = 1, PackA is actually not necessary: PackA operations for FP16 in these two libraries are both simply matrix transposition. In this case, we don’t need to do the transposition. We can just pass the pointer of the original A matrix buffer to the packed A buffer.

Reviewed By: zhengwy888

Differential Revision: D14299246

fbshipit-source-id: 78a62c5ff3a396b59afb15462efe38461cb71e15

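The reasoning in the summary can be checked with a small standalone sketch. It assumes, as the summary states, that packing A amounts to a plain transposition; `transpose_pack` and `select_a` are hypothetical names introduced for illustration and are not FBGEMM functions. Transposing a 1 x k row yields a k x 1 column with the same memory layout, so the original pointer can be used directly.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for PackA (not FBGEMM's implementation): transposes a
// row-major nrow x ncol block of `from` (leading dimension ldim) into `to`,
// which is written as an ncol x nrow row-major block.
void transpose_pack(int nrow, int ncol, const float* from, int ldim, float* to) {
  for (int r = 0; r < nrow; ++r) {
    for (int c = 0; c < ncol; ++c) {
      to[c * nrow + r] = from[r * ldim + c];
    }
  }
}

// Hypothetical helper mirroring the branch in the commit: returns the buffer
// a kernel should read A from. For m != 1 the block is packed into `scratch`;
// for m == 1 the original row is returned untouched.
const float* select_a(int m, int kb, const float* a, int ldim,
                      std::vector<float>& scratch) {
  if (m != 1) {
    transpose_pack(m, kb, a, ldim, scratch.data());
    return scratch.data();
  }
  // A 1 x kb row and its kb x 1 transpose occupy memory identically.
  return a;
}
```

Calling `select_a(1, kb, a, ldim, scratch)` returns `a` itself and leaves `scratch` unwritten, which corresponds to skipping the PackA call when m == 1.
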
Diffstat (limited to 'src/FbgemmFP16.cc')
-rw-r--r-- src/FbgemmFP16.cc 12
1 files changed, 10 insertions, 2 deletions
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
index 868bc1b..2c0eea3 100644
--- a/src/FbgemmFP16.cc
+++ b/src/FbgemmFP16.cc
@@ -244,11 +244,19 @@ FBGEMM_API void cblas_gemm_compute(
 auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows;
 for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
   assert(kernel_nrows * kb < scratchpad->size());
-  PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
+  if (m != 1) {
+    PackA(kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
+    gp.A = scratchpad->data();
+  } else {
+    // When m == 1, it is actually vector matrix multiplication. We
+    // don't need to do the transposition for packA here. Instead, we
+    // can just pass the pointer of the original A matrix buffer to the
+    // packed A buffer.
+    gp.A = const_cast<float*>(&A[k_ind]);
+  }
   int nbcol = n / Bp.blockColSize();
   gp.k = kb;
-  gp.A = scratchpad->data();
   gp.B = &(Bp(k_ind, 0));
   gp.beta = &beta_;
   gp.accum = accum;