diff options
Diffstat (limited to 'src/PackBMatrix.cc')
-rw-r--r-- | src/PackBMatrix.cc | 260 |
1 files changed, 156 insertions, 104 deletions
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index b6d06ca..c237ac4 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -188,6 +188,76 @@ PackBMatrix<T, accT>::PackBMatrix( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { + assert(0 && "unknown architecure"); + } + + if (params) { + BaseType::brow_ = params->KCB; + BaseType::bcol_ = params->NCB; + row_interleave_ = params->ROW_INTERLEAVE; + } else { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else { + // AVX2 + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } + } + + if (BaseType::numRows() % groups != 0) { + throw std::runtime_error( + "groups = " + std::to_string(groups) + + " does not divide numRows = " + std::to_string(BaseType::numRows())); + } + + // blocking for one group + block_type_t block{ + 0, BaseType::numRows() / BaseType::numGroups(), 0, BaseType::numCols()}; + BaseType::packedBlock(block); + if (!pmat) { + BaseType::bufAllocatedHere_ = true; + BaseType::buf_ = (T*)fbgemmAlignedAlloc( + 64, + BaseType::numGroups() * BaseType::blockRows() * BaseType::brow_ * + BaseType::blockCols() * BaseType::bcol_ * sizeof(T)); + } + pack(block, params); +} + +template <typename T, typename accT> +PackBMatrix<T, accT>::PackBMatrix( + matrix_op_t trans, + int32_t nRow, + int32_t nCol, + inpType* prepackedmat, + int32_t ld, + int groups, + const BlockingFactors* params) + : PackMatrix<PackBMatrix<T, accT>, T, accT>( + nRow, + nCol, + prepackedmat, + groups, + params), + trans_(trans), + smat_(nullptr), + ld_(ld) { + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } if (params) { if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { BaseType::brow_ = params->KCB; @@ -221,20 +291,17 @@ PackBMatrix<T, accT>::PackBMatrix( // blocking for one group block_type_t block{ - 0, BaseType::numRows() / BaseType::numGroups(), 0, BaseType::numCols()}; + 0, BaseType::numRows() / BaseType::numGroups(), 0, BaseType::numCols() }; BaseType::packedBlock(block); - if (!pmat) { - BaseType::bufAllocatedHere_ = true; - BaseType::buf_ = (T*)fbgemmAlignedAlloc( - 64, - BaseType::numGroups() * BaseType::blockRows() * BaseType::brow_ * - BaseType::blockCols() * BaseType::bcol_ * sizeof(T)); - } - pack(block); } template <typename T, typename accT> -void PackBMatrix<T, accT>::pack(const block_type_t& block) { +void PackBMatrix<T, accT>::pack_unpack_( + const block_type_t& block, + T* unpack_buf, + T* pack_buf, + bool ispack, + const BlockingFactors* params) { assert((BaseType::blockRowSize() % row_interleave_) == 0); assert((block.row_start % BaseType::blockRowSize()) == 0); assert((block.col_start % BaseType::blockColSize()) == 0); @@ -242,8 +309,8 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { BaseType::packedBlock(block); bool tr = (trans_ == matrix_op_t::Transpose); for (int g = 0; g < BaseType::numGroups(); ++g) { - T* out = BaseType::getBuf() + - g * BaseType::packedBufferSize(block.row_size, block.col_size); + T* pack_buf_cur = pack_buf + + g * BaseType::packedBufferSize(block.row_size, block.col_size, params); for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * (BaseType::blockRowSize() * BaseType::blockColSize()) + @@ -268,10 +335,16 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() + c_idx_offset * row_interleave_; - int out_idx = r_offset + c_offset; - T val = tr ? smat_[i + (g * block.col_size + j) * ld_] - : smat_[(g * block.row_size + i) * ld_ + j]; - out[out_idx] = val; + if (ispack) { + pack_buf_cur[r_offset + c_offset] = tr + ? unpack_buf[i + (g * block.col_size + j) * ld_] + : unpack_buf[(g * block.row_size + i) * ld_ + j]; + } else { + T* unpack_buf_cur = tr + ? &(unpack_buf[i + (g * block.col_size + j) * ld_]) + : &(unpack_buf[(g * block.row_size + i) * ld_ + j]); + *unpack_buf_cur = pack_buf_cur[r_offset + c_offset]; + } c_idx_offset++; if (c_idx_offset == BaseType::blockColSize()) { @@ -280,78 +353,49 @@ void PackBMatrix<T, accT>::pack(const block_type_t& block) { } } } - // fill the remaining with zero. - // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. - for (int i = block.row_start + block.row_size; - i < (block.row_start + block.row_size + row_interleave_ - 1) / - row_interleave_ * row_interleave_; - ++i) { - int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * - (BaseType::blockRowSize() * BaseType::blockColSize()) + - (i % BaseType::blockRowSize() / row_interleave_) * - BaseType::blockColSize() * row_interleave_ + - i % row_interleave_; - for (int j = block.col_start; j < block.col_start + block.col_size; j++) { - int c_offset = (j / BaseType::blockColSize()) * - BaseType::blockRowSize() * BaseType::blockColSize() + - (j % BaseType::blockColSize()) * row_interleave_; + if (ispack) { + // fill the remaining with zero. + // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill. + for (int i = block.row_start + block.row_size; + i < (block.row_start + block.row_size + row_interleave_ - 1) / + row_interleave_ * row_interleave_; + ++i) { + int r_offset = + ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * + (BaseType::blockRowSize() * BaseType::blockColSize()) + + (i % BaseType::blockRowSize() / row_interleave_) * + BaseType::blockColSize() * row_interleave_ + + i % row_interleave_; + for (int j = block.col_start; j < block.col_start + block.col_size; + j++) { + int c_offset = (j / BaseType::blockColSize()) * + BaseType::blockRowSize() * BaseType::blockColSize() + + (j % BaseType::blockColSize()) * row_interleave_; - int out_idx = r_offset + c_offset; - out[out_idx] = 0; + int out_idx = r_offset + c_offset; + pack_buf_cur[out_idx] = 0; + } } } } // for each group } template <typename T, typename accT> -void PackBMatrix<T, accT>::unpack(T* origin_buf) { - bool tr = (trans_ == matrix_op_t::Transpose); - for (int g = 0; g < this->numGroups(); ++g) { - T* out = BaseType::getBuf() + - g * - BaseType::packedBufferSize( - BaseType::numPackedRows(), BaseType::numPackedCols()); - for (int i = BaseType::packedRowStart(); - i < BaseType::packedRowStart() + BaseType::numPackedRows(); - ++i) { - int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * - (BaseType::blockRowSize() * BaseType::blockColSize()) + - (i % BaseType::blockRowSize() / row_interleave_) * - BaseType::blockColSize() * row_interleave_ + - i % row_interleave_; - - int c_start_offset = - (BaseType::packedColStart() / BaseType::blockColSize()) * - BaseType::blockRowSize() * BaseType::blockColSize() + - (BaseType::packedColStart() % BaseType::blockColSize()) * - row_interleave_; - - int c_idx_offset = 0; - int c_blk_offset = 0; - for (int j = BaseType::packedColStart(); - j < BaseType::packedColStart() + BaseType::numPackedCols(); - ++j) { - int c_offset = c_start_offset + - c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() + - c_idx_offset * row_interleave_; - - int out_idx = r_offset + c_offset; - - T val = out[out_idx]; - if (tr) { - origin_buf[i + (g * BaseType::numPackedCols() + j) * ld_] = val; - } else { - origin_buf[(g * BaseType::numPackedRows() + i) * ld_ + j] = val; - } +void PackBMatrix<T, accT>::pack( + const block_type_t& block, + const BlockingFactors* params) { + pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params); +} - c_idx_offset++; - if (c_idx_offset == BaseType::blockColSize()) { - c_idx_offset = 0; - c_blk_offset++; - } - } - } - } // for each group +template <typename T, typename accT> +void PackBMatrix<T, accT>::unpack( + T* origin_buf, + const BlockingFactors* params) { + block_type_t blockB{BaseType::packedRowStart(), + BaseType::numPackedRows(), + BaseType::packedColStart(), + BaseType::numPackedCols()}; + pack_unpack_(blockB, origin_buf, BaseType::getBuf(), false, params); } template <typename T, typename accT> @@ -374,7 +418,9 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const { } template <typename T, typename accT> -void PackBMatrix<T, accT>::printPackedMatrix(std::string name) { +void PackBMatrix<T, accT>::printPackedMatrix( + std::string name, + const BlockingFactors* params) { std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", " << BaseType::numPackedCols() << "]" << std::endl; @@ -382,33 +428,39 @@ void PackBMatrix<T, accT>::printPackedMatrix(std::string name) { << "[" << BaseType::blockRowSize() << ", " << BaseType::blockColSize() << "]" << std::endl; - T* out = BaseType::getBuf(); - - for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { - auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() - : BaseType::blockRowSize(); - for (auto nc = 0; nc < BaseType::blockCols(); ++nc) { - std::cout << "block:" << nr << ", " << nc << std::endl; - auto cols = (nc == BaseType::blockCols() - 1) ? BaseType::lastBcol() - : BaseType::blockColSize(); - for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_; - ++r) { - for (auto c = 0; c < cols * row_interleave_; ++c) { - T val = - out[nr * BaseType::blockCols() * BaseType::blockRowSize() * - BaseType::blockColSize() + - nc * BaseType::blockRowSize() * BaseType::blockColSize() + - r * BaseType::blockColSize() * row_interleave_ + c]; - if (std::is_integral<T>::value) { - // cast to int64 because cout doesn't print int8_t type directly - std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; - } else { - std::cout << std::setw(5) << val << " "; + for (int g = 0; g < BaseType::numGroups(); ++g) { + T* out = BaseType::getBuf() + + g * + BaseType::packedBufferSize( + BaseType::numPackedRows(), BaseType::numPackedCols(), params); + std::cout << "group: " << g << std::endl; + for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { + auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() + : BaseType::blockRowSize(); + for (auto nc = 0; nc < BaseType::blockCols(); ++nc) { + std::cout << "block:" << nr << ", " << nc << std::endl; + auto cols = (nc == BaseType::blockCols() - 1) + ? BaseType::lastBcol() + : BaseType::blockColSize(); + for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_; + ++r) { + for (auto c = 0; c < cols * row_interleave_; ++c) { + T val = + out[nr * BaseType::blockCols() * BaseType::blockRowSize() * + BaseType::blockColSize() + + nc * BaseType::blockRowSize() * BaseType::blockColSize() + + r * BaseType::blockColSize() * row_interleave_ + c]; + if (std::is_integral<T>::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } } + std::cout << std::endl; } std::cout << std::endl; } - std::cout << std::endl; } } } |