Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaya Khudia <dskhudia@fb.com>2019-06-20 22:13:35 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-06-20 22:21:51 +0300
commit5b64af1469cf629aa7beb934eb898fd1e0b02719 (patch)
treedddef8da6e597f1c118a18cfe5ff421e97df0a88 /include
parent604575ff5de717b2ee712190634840981a9c8fba (diff)
Per channel and groupwise quantization (#99)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/99 A function to do per channel and groupwise quantization Reviewed By: jspark1105 Differential Revision: D15567272 fbshipit-source-id: e2f326ea7c7463b5c47b3f590e003344a9e41960
Diffstat (limited to 'include')
-rw-r--r--include/fbgemm/QuantUtils.h35
-rw-r--r--include/fbgemm/Utils.h7
2 files changed, 42 insertions, 0 deletions
diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h
index 43855d8..508ce7d 100644
--- a/include/fbgemm/QuantUtils.h
+++ b/include/fbgemm/QuantUtils.h
@@ -7,6 +7,7 @@
#include <limits>
#include "FbgemmBuild.h"
#include "QuantUtilsAvx2.h"
+#include "Utils.h"
namespace fbgemm {
@@ -78,6 +79,40 @@ FBGEMM_API void Quantize(
int len,
const TensorQuantizationParams& qparams);
+/*
+ * @brief Quantize floating point data in src to type T
+ *
+ * @tparam T output quantized data type (int8_t, uint8_t and int32_t are
+ * supported)
+ *
+ * @tparam T LAYOUT layout of input tensor in src. (KCX and KXC are supported)
+ * KCX corresponds to KCRS or KCTRS (for weight tensors with
+ * time dimension)
+ * KXC corresponds to KRSC or KTRSC (for weight tensors with
+ * time dimension)
+ *
+ * @params K Output channels for weight tensors
+ * @params C Number of channels
+ * @params X R*S or T*R*S
+ * @params G Groups (if G == C the function performs channelwise quantization;
+ * if 1 < G < C the function performs groupwise quantization;
+ * if G == 1 the function performs per tensor quantization;)
+ * @params scales floating point scales.
+ * Size should be equal G
+ * @params zero_points zero points (should be reprsentable in type T).
+ * Size should be equal G
+ */
+template <typename T, layout_t LAYOUT = layout_t::KCX>
+FBGEMM_API void QuantizeGroupwise(
+ const float* src,
+ int K,
+ int C,
+ int X,
+ int G,
+ const float* scales,
+ const std::int32_t* zero_points,
+ T* dst);
+
template <typename T>
FBGEMM_API float Dequantize(T src, const TensorQuantizationParams& qparams) {
return qparams.scale * (src - qparams.zero_point);
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index 1a35aa1..636abc7 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -44,6 +44,13 @@ enum class optimized_conv_t { depthwise, groupwise, im2col };
enum class impl_type_t { ref, opt };
/**
+ * @brief Typed enum to specify data layout.
+ * KCX can be KCRS format or KCTRS format (e.g., for 3-D convolutions)
+ * KXC can be KRSC format or KTRSC format (e.g., for 3-D convolutions)
+ */
+enum class layout_t { KCX, KXC };
+
+/**
* @brief A function to compare data in two buffers for closeness/equality.
*/
template <typename T>