Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/PackMatrix.cc')
-rw-r--r--src/PackMatrix.cc44
1 files changed, 34 insertions, 10 deletions
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
index 316fc06..e93b97c 100644
--- a/src/PackMatrix.cc
+++ b/src/PackMatrix.cc
@@ -18,33 +18,57 @@ PackMatrix<PT, inpType, accType>::PackMatrix(
int32_t rows,
int32_t cols,
inpType* buf,
- int groups)
+ int groups,
+ const BlockingFactors* params)
: buf_(buf), nrows_(rows), ncols_(cols), G_(groups) {
bufAllocatedHere_ = false;
+ blocking_params = params;
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
}
template <typename PT, typename inpType, typename accType>
-int PackMatrix<PT, inpType, accType>::packedBufferSize(int rows, int cols) {
+int PackMatrix<PT, inpType, accType>::packedBufferSize(
+ int rows,
+ int cols,
+ const BlockingFactors* params) {
+ int MCB, KCB, NCB;
+ if (params) {
+ MCB = params->MCB;
+ NCB = params->NCB;
+ KCB = params->KCB;
+ } else {
+ if (fbgemmHasAvx512Support()) {
+ MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB;
+ NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
+ KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
+ } else if (fbgemmHasAvx2Support()) {
+ MCB = PackingTraits<inpType, accType, inst_set_t::avx2>::MCB;
+ NCB = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
+ KCB = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecure");
+ return -1;
+ }
+ }
+
if (fbgemmHasAvx512Support()) {
if (isA()) {
- return PackingTraits<inpType, accType, inst_set_t::avx512>::MCB *
- PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
+ return MCB * KCB;
} else {
- int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
- int colBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
+ int rowBlock = KCB;
+ int colBlock = NCB;
return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
(((cols + colBlock - 1) / colBlock) * colBlock);
}
} else if (fbgemmHasAvx2Support()) {
if (isA()) {
- return PackingTraits<inpType, accType, inst_set_t::avx2>::MCB *
- PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
+ return MCB * KCB;
} else {
- int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
- int colBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
+ int rowBlock = KCB;
+ int colBlock = NCB;
return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
(((cols + colBlock - 1) / colBlock) * colBlock);
}