Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/PackAMatrix.cc')
-rw-r--r-- src/PackAMatrix.cc | 43
1 files changed, 27 insertions, 16 deletions
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
index db019db..89ec13e 100644
--- a/src/PackAMatrix.cc
+++ b/src/PackAMatrix.cc
@@ -20,24 +20,36 @@ PackAMatrix<T, accT>::PackAMatrix(
const T* smat,
int32_t ld,
inpType* pmat,
- int groups)
- : PackMatrix<PackAMatrix<T, accT>, T, accT>(nRow, nCol, pmat, groups),
+ int groups,
+ const BlockingFactors* params)
+ : PackMatrix<PackAMatrix<T, accT>, T, accT>(
+ nRow,
+ nCol,
+ pmat,
+ groups,
+ params),
trans_(trans),
smat_(smat),
ld_(ld) {
- if (fbgemmHasAvx512Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx2Support()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ if (params) {
+ BaseType::brow_ = params->MCB;
+ BaseType::bcol_ = params->KCB;
+ row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
+ if (fbgemmHasAvx512Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx2Support()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecure");
+ }
}
if (BaseType::numCols() % groups != 0) {
throw std::runtime_error(
@@ -46,8 +58,7 @@ PackAMatrix<T, accT>::PackAMatrix(
}
if (pmat) {
BaseType::buf_ = pmat;
- }
- else {
+ } else {
BaseType::bufAllocatedHere_ = true;
BaseType::buf_ = (T*)fbgemmAlignedAlloc(
64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));