Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Mike Tsai <miketsai@fb.com> 2019-06-15 03:04:25 +0300
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com> 2019-06-15 03:10:29 +0300
commit: 604575ff5de717b2ee712190634840981a9c8fba (patch)
tree: 198817e0992810a7dff5ac0ed2a99a9b08834346
parent: 5e71d2c304663f3b4e50cee723b8e98a867d11ca (diff)
Update the logic of checking valid parameters.
Summary: Add the check on NR_MIN and fix ymm/zmm register checks.
Reviewed By: dskhudia
Differential Revision: D15772144
fbshipit-source-id: 11e2c67fb3d47c5570b38ceaf9828ced0e60e65b
-rw-r--r-- include/fbgemm/Utils.h 23
-rw-r--r-- src/GenerateKernelU8S8S32ACC16Avx512.cc 7
2 files changed, 20 insertions, 10 deletions
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index ef1d4ab..1a35aa1 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -119,10 +119,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
return false;
if (fbgemmHasAvx512Support()) {
- if (param->NR != 16)
+ if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
return false;
} else if (fbgemmHasAvx2Support()) {
- if (param->NR != 8)
+ if (param->NR_MIN != 8 || param->NR % param->NR_MIN)
return false;
}
} else if (is_16bit) {
@@ -130,10 +130,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
return false;
if (fbgemmHasAvx512Support()) {
- if (param->NR != 32)
+ if (param->NR_MIN != 32 || param->NR % param->NR_MIN)
return false;
} else if (fbgemmHasAvx2Support()) {
- if (param->NR != 16)
+ if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
return false;
}
}
@@ -143,10 +143,19 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
if (param->NCB % param->NR)
return false;
if (fbgemmHasAvx512Support()) {
- if (param->MR * (param->NCB / param->NR) > 24)
- return false;
+ if (is_32bit) {
+ // Zmm register usage for C
+ if (param->MR * (param->NR / param->NR_MIN) > 28)
+ return false;
+ } else if (is_16bit) {
+ // Zmm register usage for C + one row for loading B
+ if ((param->MR * (param->NR / param->NR_MIN) +
+ (param->NR / param->NR_MIN)) > 28)
+ return false;
+ }
+
} else if (fbgemmHasAvx2Support()) {
- if (param->MR * (param->NCB / param->NR) > 16)
+ if (param->MR * (param->NR / param->NR_MIN) > 12)
return false;
}
return true;
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index 505fec1..e5687eb 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -201,9 +201,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
int maxMRegs = mRegBlockSize;
int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
assert(
- maxMRegs * maxNRegs <= 24 &&
- "MR*(NR*ROW_INTERLEAVE*8/512) \
- must be <= 24(available registers constraint)");
+ (maxMRegs+1) * maxNRegs <= 28 &&
+ "number of zmm registers for C + one row for loading B: \
+ MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \
+ must be <= 28(available registers constraint)");
int mRegBlocks = mc / mRegBlockSize;
int mRegBlocksRem = mc % mRegBlockSize;