diff options
Diffstat (limited to 'src/PackAWithQuantRowOffset.cc')
-rw-r--r-- | src/PackAWithQuantRowOffset.cc | 61 |
1 files changed, 39 insertions, 22 deletions
diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc index 2929ebb..175425f 100644 --- a/src/PackAWithQuantRowOffset.cc +++ b/src/PackAWithQuantRowOffset.cc @@ -28,12 +28,14 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( float scale, int32_t zero_pt, int groups, - int32_t* row_offset) + int32_t* row_offset, + const BlockingFactors* params) : PackMatrix<PackAWithQuantRowOffset<T, accT>, T, accT>( nRow, nCol, pmat, - groups), + groups, + params), trans_(trans), smat_(smat), ld_(ld), @@ -41,20 +43,30 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( zero_pt_(zero_pt), row_offset_(row_offset) { rowOffsetAllocatedHere = false; - - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + if (params) { + if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } } else { - // TODO: Have default slower path - assert(0 && "unknown architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unknown architecure"); + } } if (BaseType::numCols() % groups != 0) { throw std::runtime_error( @@ -179,15 +191,20 @@ void PackAWithQuantRowOffset<T, accT>::printPackedMatrix(std::string name) { } template <typename T, typename accT> -int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize() { +int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize( + const BlockingFactors* params) { if (cpuinfo_initialize()) { - if (fbgemmHasAvx512Support()) { - return PackingTraits<T, accT, inst_set_t::avx512>::MCB; - } else if (fbgemmHasAvx2Support()) { - return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + if (params) { + return params->MCB; } else { - assert(0 && "unsupported architecture"); - return -1; + if (fbgemmHasAvx512Support()) { + return PackingTraits<T, accT, inst_set_t::avx512>::MCB; + } else if (fbgemmHasAvx2Support()) { + return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + } else { + assert(0 && "unsupported architecture"); + return -1; + } } } else { throw std::runtime_error("Failed to initialize cpuinfo!"); |