diff options
Diffstat (limited to 'src/PackAWithQuantRowOffset.cc')
-rw-r--r-- | src/PackAWithQuantRowOffset.cc | 37 |
1 files changed, 22 insertions, 15 deletions
diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc index 305a298..13a8fad 100644 --- a/src/PackAWithQuantRowOffset.cc +++ b/src/PackAWithQuantRowOffset.cc @@ -45,32 +45,37 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - rowOffsetAllocatedHere = false; + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { + assert(0 && "unknown architecure"); + } + if (params) { - if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { - BaseType::brow_ = params->MCB; - BaseType::bcol_ = params->KCB; - row_interleave_B_ = params->ROW_INTERLEAVE; - } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); - } + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { + } else { + // AVX2 BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; row_interleave_B_ = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; - } else { - // TODO: Have default slower path - assert(0 && "unknown architecure"); } } + + rowOffsetAllocatedHere = false; + if (BaseType::numCols() % groups != 0) { throw std::runtime_error( "groups = " + std::to_string(groups) + @@ -202,7 +207,9 @@ int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + } else if (fbgemmHasAvx512Support()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits<T, accT, inst_set_t::avx2>::MCB; |