diff options
Diffstat (limited to 'include/fbgemm/Utils.h')
-rw-r--r-- | include/fbgemm/Utils.h | 50 |
1 files changed, 41 insertions, 9 deletions
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index 9f8e1ee..3976790 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -5,6 +5,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 #pragma once
+#include <array>
 #include <string>
 #include <type_traits>
 #include "FbgemmBuild.h"
@@ -39,12 +40,12 @@ enum class matrix_op_t { NoTranspose, Transpose };
 /**
  * @brief Typed enum for supported instruction sets.
  */
-enum class inst_set_t { anyarch, avx2, avx512 };
+enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni };
 
 /**
  * @brief Typed enum for optimized paths for convolutions
  */
-enum class optimized_conv_t { depthwise, groupwise, im2col };
+enum class optimized_conv_t { depthwise, groupwise, pointwise, im2col };
 
 /**
  * @brief Typed enum for implementation type.
@@ -54,6 +55,13 @@ enum class optimized_conv_t { depthwise, groupwise, im2col };
 enum class impl_type_t { ref, opt };
 
 /**
+ * @brief Typed enum to specify data layout.
+ * KCX can be KCRS format or KCTRS format (e.g., for 3-D convolutions)
+ * KXC can be KRSC format or KTRSC format (e.g., for 3-D convolutions)
+ */
+enum class layout_t { KCX, KXC };
+
+/**
  * @brief A function to compare data in two buffers for closeness/equality.
  */
 template <typename T>
@@ -103,6 +111,11 @@ FBGEMM_API bool fbgemmHasAvx512Support();
 FBGEMM_API bool fbgemmHasAvx2Support();
 
 /**
+ * @brief Are we running on a AVX512_VNNI supported cpu?
+ */
+FBGEMM_API bool fbgemmHasAvx512VnniSupport();
+
+/**
  * @brief Helper struct to enable autotuning of FBGEMM packing and kernels.
  *
  * This structure is optional. If not used, the default values for these
@@ -119,6 +132,16 @@ struct FBGEMM_API BlockingFactors {
   int NCB;
 };
 
+template <int SIZE, typename T = std::int32_t>
+FBGEMM_API std::string arrayToString(const std::array<T, SIZE>& inp) {
+  std::string out = "[";
+  for (int i = 0; i < SIZE; ++i) {
+    out += std::to_string(inp[i]);
+    out += (i != SIZE - 1) ? std::string(", ") : std::string("]");
+  }
+  return out;
+}
+
 template <typename accT = std::int32_t>
 FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
   constexpr bool is_32bit = std::is_same<accT, int32_t>::value;
@@ -129,10 +152,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
       return false;
 
     if (fbgemmHasAvx512Support()) {
-      if (param->NR != 16)
+      if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
         return false;
     } else if (fbgemmHasAvx2Support()) {
-      if (param->NR != 8)
+      if (param->NR_MIN != 8 || param->NR % param->NR_MIN)
         return false;
     }
   } else if (is_16bit) {
@@ -140,10 +163,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
       return false;
 
     if (fbgemmHasAvx512Support()) {
-      if (param->NR != 32)
+      if (param->NR_MIN != 32 || param->NR % param->NR_MIN)
         return false;
     } else if (fbgemmHasAvx2Support()) {
-      if (param->NR != 16)
+      if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
         return false;
     }
   }
@@ -153,10 +176,19 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
   if (param->NCB % param->NR)
     return false;
   if (fbgemmHasAvx512Support()) {
-    if (param->MR * (param->NCB / param->NR) > 24)
-      return false;
+    if (is_32bit) {
+      // Zmm register usage for C
+      if (param->MR * (param->NR / param->NR_MIN) > 28)
+        return false;
+    } else if (is_16bit) {
+      // Zmm register usage for C + one row for loading B
+      if ((param->MR * (param->NR / param->NR_MIN) +
+           (param->NR / param->NR_MIN)) > 28)
+        return false;
+    }
+
   } else if (fbgemmHasAvx2Support()) {
-    if (param->MR * (param->NCB / param->NR) > 16)
+    if (param->MR * (param->NR / param->NR_MIN) > 12)
       return false;
   }
   return true;