Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'include/fbgemm/Utils.h')
-rw-r--r-- include/fbgemm/Utils.h | 50
1 file changed, 41 insertions, 9 deletions
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index 9f8e1ee..3976790 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -5,6 +5,7 @@
* LICENSE file in the root directory of this source tree.
*/
#pragma once
+#include <array>
#include <string>
#include <type_traits>
#include "FbgemmBuild.h"
@@ -39,12 +40,12 @@ enum class matrix_op_t { NoTranspose, Transpose };
/**
* @brief Typed enum for supported instruction sets.
*/
-enum class inst_set_t { anyarch, avx2, avx512 };
+enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni };
/**
* @brief Typed enum for optimized paths for convolutions
*/
-enum class optimized_conv_t { depthwise, groupwise, im2col };
+enum class optimized_conv_t { depthwise, groupwise, pointwise, im2col };
/**
* @brief Typed enum for implementation type.
@@ -54,6 +55,13 @@ enum class optimized_conv_t { depthwise, groupwise, im2col };
enum class impl_type_t { ref, opt };
/**
+ * @brief Typed enum to specify data layout.
+ * KCX can be KCRS format or KCTRS format (e.g., for 3-D convolutions)
+ * KXC can be KRSC format or KTRSC format (e.g., for 3-D convolutions)
+ */
+enum class layout_t { KCX, KXC };
+
+/**
* @brief A function to compare data in two buffers for closeness/equality.
*/
template <typename T>
@@ -103,6 +111,11 @@ FBGEMM_API bool fbgemmHasAvx512Support();
FBGEMM_API bool fbgemmHasAvx2Support();
/**
+ * @brief Are we running on an AVX512_VNNI-supported CPU?
+ */
+FBGEMM_API bool fbgemmHasAvx512VnniSupport();
+
+/**
* @brief Helper struct to enable autotuning of FBGEMM packing and kernels.
*
* This structure is optional. If not used, the default values for these
@@ -119,6 +132,16 @@ struct FBGEMM_API BlockingFactors {
int NCB;
};
+template <int SIZE, typename T = std::int32_t>
+FBGEMM_API std::string arrayToString(const std::array<T, SIZE>& inp) {
+ std::string out = "[";
+ for (int i = 0; i < SIZE; ++i) {
+ out += std::to_string(inp[i]);
+ out += (i != SIZE - 1) ? std::string(", ") : std::string("]");
+ }
+ return out;
+}
+
template <typename accT = std::int32_t>
FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
constexpr bool is_32bit = std::is_same<accT, int32_t>::value;
@@ -129,10 +152,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
return false;
if (fbgemmHasAvx512Support()) {
- if (param->NR != 16)
+ if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
return false;
} else if (fbgemmHasAvx2Support()) {
- if (param->NR != 8)
+ if (param->NR_MIN != 8 || param->NR % param->NR_MIN)
return false;
}
} else if (is_16bit) {
@@ -140,10 +163,10 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
return false;
if (fbgemmHasAvx512Support()) {
- if (param->NR != 32)
+ if (param->NR_MIN != 32 || param->NR % param->NR_MIN)
return false;
} else if (fbgemmHasAvx2Support()) {
- if (param->NR != 16)
+ if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
return false;
}
}
@@ -153,10 +176,19 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
if (param->NCB % param->NR)
return false;
if (fbgemmHasAvx512Support()) {
- if (param->MR * (param->NCB / param->NR) > 24)
- return false;
+ if (is_32bit) {
+ // Zmm register usage for C
+ if (param->MR * (param->NR / param->NR_MIN) > 28)
+ return false;
+ } else if (is_16bit) {
+ // Zmm register usage for C + one row for loading B
+ if ((param->MR * (param->NR / param->NR_MIN) +
+ (param->NR / param->NR_MIN)) > 28)
+ return false;
+ }
+
} else if (fbgemmHasAvx2Support()) {
- if (param->MR * (param->NCB / param->NR) > 16)
+ if (param->MR * (param->NR / param->NR_MIN) > 12)
return false;
}
return true;