diff options
Diffstat (limited to 'src/PackAMatrix.cc')
-rw-r--r-- | src/PackAMatrix.cc | 43 |
1 files changed, 27 insertions, 16 deletions
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc index db019db..89ec13e 100644 --- a/src/PackAMatrix.cc +++ b/src/PackAMatrix.cc @@ -20,24 +20,36 @@ PackAMatrix<T, accT>::PackAMatrix( const T* smat, int32_t ld, inpType* pmat, - int groups) - : PackMatrix<PackAMatrix<T, accT>, T, accT>(nRow, nCol, pmat, groups), + int groups, + const BlockingFactors* params) + : PackMatrix<PackAMatrix<T, accT>, T, accT>( + nRow, + nCol, + pmat, + groups, + params), trans_(trans), smat_(smat), ld_(ld) { - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + if (params) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } } if (BaseType::numCols() % groups != 0) { throw std::runtime_error( @@ -46,8 +58,7 @@ PackAMatrix<T, accT>::PackAMatrix( } if (pmat) { BaseType::buf_ = pmat; - } - else { + } else { BaseType::bufAllocatedHere_ = true; BaseType::buf_ = (T*)fbgemmAlignedAlloc( 64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)); |