diff options
Diffstat (limited to 'src/PackMatrix.cc')
-rw-r--r-- | src/PackMatrix.cc | 44 |
1 files changed, 34 insertions, 10 deletions
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc index 316fc06..e93b97c 100644 --- a/src/PackMatrix.cc +++ b/src/PackMatrix.cc @@ -18,33 +18,57 @@ PackMatrix<PT, inpType, accType>::PackMatrix( int32_t rows, int32_t cols, inpType* buf, - int groups) + int groups, + const BlockingFactors* params) : buf_(buf), nrows_(rows), ncols_(cols), G_(groups) { bufAllocatedHere_ = false; + blocking_params = params; if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } } template <typename PT, typename inpType, typename accType> -int PackMatrix<PT, inpType, accType>::packedBufferSize(int rows, int cols) { +int PackMatrix<PT, inpType, accType>::packedBufferSize( + int rows, + int cols, + const BlockingFactors* params) { + int MCB, KCB, NCB; + if (params) { + MCB = params->MCB; + NCB = params->NCB; + KCB = params->KCB; + } else { + if (fbgemmHasAvx512Support()) { + MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB; + NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB; + KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; + } else if (fbgemmHasAvx2Support()) { + MCB = PackingTraits<inpType, accType, inst_set_t::avx2>::MCB; + NCB = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB; + KCB = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + return -1; + } + } + if (fbgemmHasAvx512Support()) { if (isA()) { - return PackingTraits<inpType, accType, inst_set_t::avx512>::MCB * - PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; + return MCB * KCB; } else { - int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; - int colBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB; + int rowBlock = KCB; + int colBlock = NCB; return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * (((cols + colBlock - 1) / colBlock) * colBlock); } } else if (fbgemmHasAvx2Support()) { if (isA()) { - return PackingTraits<inpType, accType, inst_set_t::avx2>::MCB * - PackingTraits<inpType, accType, inst_set_t::avx2>::KCB; + return MCB * KCB; } else { - int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB; - int colBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB; + int rowBlock = KCB; + int colBlock = NCB; return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * (((cols + colBlock - 1) / colBlock) * colBlock); } |