diff options
Diffstat (limited to 'include/fbgemm/QuantUtils.h')
-rw-r--r-- | include/fbgemm/QuantUtils.h | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h index 43855d8..508ce7d 100644 --- a/include/fbgemm/QuantUtils.h +++ b/include/fbgemm/QuantUtils.h @@ -7,6 +7,7 @@ #include <limits> #include "FbgemmBuild.h" #include "QuantUtilsAvx2.h" +#include "Utils.h" namespace fbgemm { @@ -78,6 +79,40 @@ FBGEMM_API void Quantize( int len, const TensorQuantizationParams& qparams); +/* + * @brief Quantize floating point data in src to type T + * + * @tparam T output quantized data type (int8_t, uint8_t and int32_t are + * supported) + * + * @tparam LAYOUT layout of input tensor in src. (KCX and KXC are supported) + * KCX corresponds to KCRS or KCTRS (for weight tensors with + * time dimension) + * KXC corresponds to KRSC or KTRSC (for weight tensors with + * time dimension) + * + * @param K Output channels for weight tensors + * @param C Number of channels + * @param X R*S or T*R*S + * @param G Groups (if G == C the function performs channelwise quantization; + * if 1 < G < C the function performs groupwise quantization; + * if G == 1 the function performs per-tensor quantization;) + * @param scales floating point scales. + * Size should be equal to G + * @param zero_points zero points (should be representable in type T). + * Size should be equal to G + */ +template <typename T, layout_t LAYOUT = layout_t::KCX> +FBGEMM_API void QuantizeGroupwise( + const float* src, + int K, + int C, + int X, + int G, + const float* scales, + const std::int32_t* zero_points, + T* dst); + template <typename T> FBGEMM_API float Dequantize(T src, const TensorQuantizationParams& qparams) { return qparams.scale * (src - qparams.zero_point); |