Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/PackBMatrix.cc')
-rw-r--r--src/PackBMatrix.cc25
1 files changed, 17 insertions, 8 deletions
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index 0990edb..bf43fab 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -188,7 +188,8 @@ PackBMatrix<T, accT>::PackBMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -197,7 +198,12 @@ PackBMatrix<T, accT>::PackBMatrix(
BaseType::bcol_ = params->NCB;
row_interleave_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB;
row_interleave_ =
@@ -317,14 +323,16 @@ void PackBMatrix<T, accT>::pack_unpack_(
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::pack(const block_type_t& block,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::pack(
+ const block_type_t& block,
+ const BlockingFactors* params) {
pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params);
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::unpack(T* origin_buf,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::unpack(
+ T* origin_buf,
+ const BlockingFactors* params) {
block_type_t blockB{BaseType::packedRowStart(),
BaseType::numPackedRows(),
BaseType::packedColStart(),
@@ -352,8 +360,9 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const {
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::printPackedMatrix(std::string name,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::printPackedMatrix(
+ std::string name,
+ const BlockingFactors* params) {
std::cout << name << ":"
<< "[" << BaseType::numPackedRows() << ", "
<< BaseType::numPackedCols() << "]" << std::endl;