diff options
Diffstat (limited to 'src/PackBMatrix.cc')
-rw-r--r-- | src/PackBMatrix.cc | 25 |
1 files changed, 17 insertions, 8 deletions
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index 0990edb..bf43fab 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -188,7 +188,8 @@ PackBMatrix<T, accT>::PackBMatrix( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -197,7 +198,12 @@ PackBMatrix<T, accT>::PackBMatrix( BaseType::bcol_ = params->NCB; row_interleave_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB; row_interleave_ = @@ -317,14 +323,16 @@ void PackBMatrix<T, accT>::pack_unpack_( } template <typename T, typename accT> -void PackBMatrix<T, accT>::pack(const block_type_t& block, - const BlockingFactors* params) { +void PackBMatrix<T, accT>::pack( + const block_type_t& block, + const BlockingFactors* params) { pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params); } template <typename T, typename accT> -void PackBMatrix<T, accT>::unpack(T* origin_buf, - const BlockingFactors* params) { +void PackBMatrix<T, accT>::unpack( + T* origin_buf, + const BlockingFactors* params) { block_type_t blockB{BaseType::packedRowStart(), BaseType::numPackedRows(), BaseType::packedColStart(), @@ -352,8 +360,9 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const { } template <typename T, typename accT> -void PackBMatrix<T, accT>::printPackedMatrix(std::string name, - const BlockingFactors* params) { +void PackBMatrix<T, accT>::printPackedMatrix( + std::string name, + const BlockingFactors* params) { std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", " << BaseType::numPackedCols() << "]" << std::endl; |