Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/PackMatrix.cc')
-rw-r--r--src/PackMatrix.cc40
1 files changed, 16 insertions, 24 deletions
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
index 33227fb..c7503dd 100644
--- a/src/PackMatrix.cc
+++ b/src/PackMatrix.cc
@@ -36,45 +36,37 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
+ if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ assert(0 && "unknown architecure");
+ }
+
int MCB, KCB, NCB;
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
- MCB = params->MCB;
- NCB = params->NCB;
- KCB = params->KCB;
- } else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
- }
+ MCB = params->MCB;
+ NCB = params->NCB;
+ KCB = params->KCB;
} else {
if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB;
NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
- } else if (fbgemmHasAvx2Support()) {
+ } else {
+ // AVX2
MCB = PackingTraits<inpType, accType, inst_set_t::avx2>::MCB;
NCB = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
KCB = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
- } else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
- return -1;
}
}
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
- if (isA()) {
- return MCB * KCB;
- } else {
- int rowBlock = KCB;
- int colBlock = NCB;
- return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
- (((cols + colBlock - 1) / colBlock) * colBlock);
- }
+ if (isA()) {
+ return MCB * KCB;
} else {
- // TODO: Have default slower path
- assert(0 && "unsupported architecure");
+ int rowBlock = KCB;
+ int colBlock = NCB;
+ return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
+ (((cols + colBlock - 1) / colBlock) * colBlock);
}
+
return -1;
}