diff options
Diffstat (limited to 'include/fbgemm')
-rw-r--r-- | include/fbgemm/QuantUtils.h | 35 | ||||
-rw-r--r-- | include/fbgemm/Utils.h | 7 |
2 files changed, 42 insertions, 0 deletions
diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h index 43855d8..508ce7d 100644 --- a/include/fbgemm/QuantUtils.h +++ b/include/fbgemm/QuantUtils.h @@ -7,6 +7,7 @@ #include <limits> #include "FbgemmBuild.h" #include "QuantUtilsAvx2.h" +#include "Utils.h" namespace fbgemm { @@ -78,6 +79,40 @@ FBGEMM_API void Quantize( int len, const TensorQuantizationParams& qparams); +/* + * @brief Quantize floating point data in src to type T + * + * @tparam T output quantized data type (int8_t, uint8_t and int32_t are + * supported) + * + * @tparam LAYOUT layout of input tensor in src. (KCX and KXC are supported) + * KCX corresponds to KCRS or KCTRS (for weight tensors with + * time dimension) + * KXC corresponds to KRSC or KTRSC (for weight tensors with + * time dimension) + * + * @param K Output channels for weight tensors + * @param C Number of channels + * @param X R*S or T*R*S + * @param G Groups (if G == C the function performs channelwise quantization; + * if 1 < G < C the function performs groupwise quantization; + * if G == 1 the function performs per tensor quantization;) + * @param scales floating point scales. + * Size should be equal to G + * @param zero_points zero points (should be representable in type T). + * Size should be equal to G + */ +template <typename T, layout_t LAYOUT = layout_t::KCX> +FBGEMM_API void QuantizeGroupwise( + const float* src, + int K, + int C, + int X, + int G, + const float* scales, + const std::int32_t* zero_points, + T* dst); + template <typename T> FBGEMM_API float Dequantize(T src, const TensorQuantizationParams& qparams) { return qparams.scale * (src - qparams.zero_point); diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index 1a35aa1..636abc7 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -44,6 +44,13 @@ enum class optimized_conv_t { depthwise, groupwise, im2col }; enum class impl_type_t { ref, opt }; /** + * @brief Typed enum to specify data layout. 
+ * KCX can be KCRS format or KCTRS format (e.g., for 3-D convolutions) + * KXC can be KRSC format or KTRSC format (e.g., for 3-D convolutions) + */ +enum class layout_t { KCX, KXC }; + +/** * @brief A function to compare data in two buffers for closeness/equality. */ template <typename T> |