From 604575ff5de717b2ee712190634840981a9c8fba Mon Sep 17 00:00:00 2001 From: Mike Tsai Date: Fri, 14 Jun 2019 17:04:25 -0700 Subject: Update the logic of checking valid parameters. Summary: Add the check on NR_MIN and fix ymm/zmm register checks. Reviewed By: dskhudia Differential Revision: D15772144 fbshipit-source-id: 11e2c67fb3d47c5570b38ceaf9828ced0e60e65b --- include/fbgemm/Utils.h | 23 ++++++++++++++++------- src/GenerateKernelU8S8S32ACC16Avx512.cc | 7 ++++--- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index ef1d4ab..1a35aa1 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -119,10 +119,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) { return false; if (fbgemmHasAvx512Support()) { - if (param->NR != 16) + if (param->NR_MIN != 16 || param->NR % param->NR_MIN) return false; } else if (fbgemmHasAvx2Support()) { - if (param->NR != 8) + if (param->NR_MIN != 8 || param->NR % param->NR_MIN) return false; } } else if (is_16bit) { @@ -130,10 +130,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) { return false; if (fbgemmHasAvx512Support()) { - if (param->NR != 32) + if (param->NR_MIN != 32 || param->NR % param->NR_MIN) return false; } else if (fbgemmHasAvx2Support()) { - if (param->NR != 16) + if (param->NR_MIN != 16 || param->NR % param->NR_MIN) return false; } } @@ -143,10 +143,19 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) { if (param->NCB % param->NR) return false; if (fbgemmHasAvx512Support()) { - if (param->MR * (param->NCB / param->NR) > 24) - return false; + if (is_32bit) { + // Zmm register usage for C + if (param->MR * (param->NR / param->NR_MIN) > 28) + return false; + } else if (is_16bit) { + // Zmm register usage for C + one row for loading B + if ((param->MR * (param->NR / param->NR_MIN) + + (param->NR / param->NR_MIN)) > 28) + return false; + } + } else if (fbgemmHasAvx2Support()) { - if (param->MR * (param->NCB / param->NR) > 16) + if (param->MR * (param->NR / param->NR_MIN) > 12) return false; } return true; diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index 505fec1..e5687eb 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -201,9 +201,10 @@ CodeGenBase::getOrCreate( int maxMRegs = mRegBlockSize; int maxNRegs = nRegBlockSize * row_interleave / VLEN_; assert( - maxMRegs * maxNRegs <= 24 && - "MR*(NR*ROW_INTERLEAVE*8/512) \ - must be <= 24(available registers constraint)"); + (maxMRegs+1) * maxNRegs <= 28 && + "number of zmm registers for C + one row for loading B: \ + MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \ + must be <= 28(available registers constraint)"); int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; -- cgit v1.2.3