diff options
Diffstat (limited to 'src/PackAWithRowOffset.cc')
-rw-r--r-- | src/PackAWithRowOffset.cc | 50 |
1 files changed, 31 insertions, 19 deletions
diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc index 7777f1a..139a6d3 100644 --- a/src/PackAWithRowOffset.cc +++ b/src/PackAWithRowOffset.cc @@ -24,31 +24,38 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset( uint32_t ld, inpType* pmat, int groups, - int32_t* row_offset) + int32_t* row_offset, + const BlockingFactors* params) : PackMatrix<PackAWithRowOffset<T, accT>, T, accT>( nRow, nCol, pmat, - groups), + groups, + params), trans_(trans), smat_(smat), ld_(ld), row_offset_(row_offset) { rowOffsetAllocatedHere = false; - - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + if (params) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - // TODO: Have default slower path - assert(0 && "unknown architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unknown architecure"); + } } if (BaseType::numCols() % groups != 0) { throw std::runtime_error( @@ -169,17 +176,22 @@ void PackAWithRowOffset<T, accT>::printPackedMatrix(std::string name) { } template <typename T, typename accT> -int PackAWithRowOffset<T, accT>::rowOffsetBufferSize() { +int PackAWithRowOffset<T, accT>::rowOffsetBufferSize( + const BlockingFactors* params) { if (cpuinfo_initialize()) { + if (params){ + return params->MCB; + } else { if (fbgemmHasAvx512Support()) { - return PackingTraits<T, accT, inst_set_t::avx512>::MCB; - } else if (fbgemmHasAvx2Support()) { - return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + return PackingTraits<T, accT, inst_set_t::avx512>::MCB; + } else if (fbgemmHasAvx2Support()) { + return PackingTraits<T, accT, inst_set_t::avx2>::MCB; } else { // TODO: Have default slower path assert(0 && "unsupported architecture"); return -1; } + } } else { throw std::runtime_error("Failed to initialize cpuinfo!"); } |