github.com/marian-nmt/intgemm/intgemm.git
author    Kenneth Heafield <kpu@users.noreply.github.com>    2020-03-25 20:54:37 +0300
committer GitHub <noreply@github.com>                        2020-03-25 20:54:37 +0300
commit    14bbbdbe83e46d6da7df29f8c7a7961039a58ec5 (patch)
tree      30c06fa7bf022f5ee9ad4346af42c34c16c449eb
parent    65176b06d3caea37bd0d9d5154686f073f37ad6b (diff)
parent    d596aace3e80a223358672e7fbd178e83d2ab609 (diff)
Merge pull request #54 from kpu/stdQuantizer
Add standard deviation quantizer
-rw-r--r--  avx2_gemm.h            |  2 +
-rw-r--r--  avx512_gemm.h          |  2 +
-rw-r--r--  intgemm.cc             |  2 +
-rw-r--r--  intgemm.h              |  8 +
-rw-r--r--  multiply.h             | 49 +
-rw-r--r--  ssse3_gemm.h           |  2 +
-rw-r--r--  test/quantize_test.cc  | 84 +
7 files changed, 149 insertions(+), 0 deletions(-)
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 25866bb..529d628 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -240,6 +240,8 @@ struct AVX2_8bit {
INTGEMM_PREPAREBIASFOR8(__m256i, INTGEMM_AVX2, CPUType::AVX2)
+ INTGEMM_GETQUANTIZERSTD(__m256, INTGEMM_AVX2)
+
constexpr static const char *const kName = "8-bit AVX2";
static const CPUType kUses = CPUType::AVX2;
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 6286ccc..b8a561f 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -433,6 +433,8 @@ struct AVX512_8bit {
INTGEMM_PREPAREBIASFOR8(__m512i, INTGEMM_AVX512BW, CPUType::AVX2)
+ INTGEMM_GETQUANTIZERSTD(__m512, INTGEMM_AVX512BW)
+
constexpr static const char *const kName = "8-bit AVX512BW";
static const CPUType kUses = CPUType::AVX512BW;
diff --git a/intgemm.cc b/intgemm.cc
index 8838cdb..5014b47 100644
--- a/intgemm.cc
+++ b/intgemm.cc
@@ -40,6 +40,8 @@ const CPUType kCPU = ChooseCPU(CPUType::AVX512VNNI, CPUType::AVX512BW, CPUType::
float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(avx512f::MaxAbsolute, avx512f::MaxAbsolute, avx2::MaxAbsolute, sse2::MaxAbsolute, sse2::MaxAbsolute, Unsupported_MaxAbsolute);
+MeanStd (*GetQuantizerStd)(const float *begin, const float *end) = ChooseCPU(AVX512VNNI_8bit::GetQuantizerStd, AVX512_8bit::GetQuantizerStd, AVX2_8bit::GetQuantizerStd, SSSE3_8bit::GetQuantizerStd, Unsupported_8bit::GetQuantizerStd, Unsupported_8bit::GetQuantizerStd);
+
constexpr const char *const Unsupported_16bit::kName;
constexpr const char *const Unsupported_8bit::kName;
constexpr const char *const SSE2_16bit::kName;
diff --git a/intgemm.h b/intgemm.h
index 0c315fc..ac877e8 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -116,6 +116,11 @@ struct Unsupported_8bit {
static void Multiply8Shift(const uint8_t *, const int8_t *, Index, Index, Index, Callback) {
throw UnsupportedCPU();
}
+
+ static MeanStd GetQuantizerStd(const float *, const float *) {
+ throw UnsupportedCPU();
+ }
+
constexpr static const char *const kName = "8-bit Unsupported";
};
@@ -271,6 +276,9 @@ struct Int8 {
MultiplyImpl<Callback>::run(A, B, A_rows, width, B_cols, callback);
}
+ // Get a quantization value that is equal to the mean of the data plus N standard deviations. Use 2 by default.
+ static MeanStd (*GetQuantizerStd)(const float *begin, const float *end);
+
static const char *const kName;
private:
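An aside on intent: the declaration above returns a MeanStd, and the comment suggests quantizing at the mean plus N standard deviations, with N = 2 by default. A minimal caller-side sketch of how that statistic might feed a quantization multiplier follows; the 127 / (mean + 2 * stddev) formula is an illustrative assumption, not something this patch defines.

#include "intgemm.h"

// Hypothetical helper (not part of this patch): derive an 8-bit quantization
// multiplier from the new statistic instead of the absolute maximum.
// The input length must be a multiple of the SIMD register width and suitably
// aligned (see the asserts in INTGEMM_GETQUANTIZERSTD).
inline float QuantMultFromStd(const float *begin, const float *end) {
  intgemm::MeanStd stats = intgemm::Int8::GetQuantizerStd(begin, end);
  // Assumption for illustration: clip at mean + 2 standard deviations
  // ("Use 2 by default" above) and map that range onto the int8 limit of 127.
  return 127.0f / (stats.mean + 2.0f * stats.stddev);
}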
diff --git a/multiply.h b/multiply.h
index 84d6737..ee2196e 100644
--- a/multiply.h
+++ b/multiply.h
@@ -6,8 +6,15 @@
#include "vec_traits.h"
#include "callbacks.h"
+#include <cmath> //sqrt
+
namespace intgemm {
+struct MeanStd {
+ float mean;
+ float stddev;
+};
+
INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
// Fold to just using the first 64 bits.
__m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
@@ -36,6 +43,22 @@ INTGEMM_AVX2 static inline __m256i PermuteSummer(__m256i pack0123, __m256i pack4
return _mm256_add_epi32(rev, blended);
}
+/* https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-float-vector-sum-on-x86 */
+INTGEMM_SSSE3 static inline float horizontalSum(__m128 a) {
+ __m128 shuf = _mm_movehdup_ps(a); // broadcast elements 3,1 to 2,0
+ __m128 sums = _mm_add_ps(a, shuf);
+ shuf = _mm_movehl_ps(shuf, sums); // high half -> low half
+ sums = _mm_add_ss(sums, shuf);
+ return _mm_cvtss_f32(sums);
+}
+
+INTGEMM_AVX2 static inline float horizontalSum(__m256 a) {
+ __m128 vlow = _mm256_castps256_ps128(a);
+ __m128 vhigh = _mm256_extractf128_ps(a, 1); // high 128
+ vlow = _mm_add_ps(vlow, vhigh); // add the low 128
+ return horizontalSum(vlow); // and inline the sse3 version, which is optimal for AVX
+}
+
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_AVX512BW static inline __m256i PermuteSummer(__m512i pack0123, __m512i pack4567) {
@@ -57,6 +80,10 @@ static inline INTGEMM_AVX512F float MaxFloat32(__m512 a) {
return MaxFloat32(max_ps(_mm512_castps512_ps256(a), upper));
}
+static inline INTGEMM_AVX512F float horizontalSum(__m512 a) {
+ return _mm512_reduce_add_ps(a);
+}
+
#endif
// Quantize function used for SSSE3 and AVX2.
@@ -615,4 +642,26 @@ target static inline float MaxAbsolute(const float *begin_float, const float *en
return ret; \
} \
+#define INTGEMM_GETQUANTIZERSTD(Register, target) \
+target static MeanStd GetQuantizerStd(const float *begin_float, const float *end_float) { \
+ /* Finds a quantizer value that is a given number of standard deviations above the mean */ \
+ assert(end_float > begin_float); \
+ assert((end_float - begin_float) % (sizeof(Register) / sizeof(float)) == 0); \
+ size_t num_items = end_float - begin_float; \
+ const Register *begin = reinterpret_cast<const Register*>(begin_float); \
+ const Register *end = reinterpret_cast<const Register*>(end_float); \
+ Register squares = set1_ps<Register>(0); \
+ Register sums = set1_ps<Register>(0); \
+ for (; begin != end; begin++) { \
+ squares = add_ps(squares, mul_ps(*begin, *begin)); \
+ sums = add_ps(sums, *begin); \
+ } \
+ float squares_sum = horizontalSum(squares); \
+ float normal_sums = horizontalSum(sums); \
+ MeanStd ret; \
+ ret.mean = normal_sums/num_items; \
+ ret.stddev = std::sqrt((squares_sum/num_items) - (ret.mean*ret.mean)); \
+ return ret; \
+} \
+
} // namespace intgemm
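For reference, the macro above accumulates a vector of element sums and a vector of squared sums in a single pass, collapses each with horizontalSum, and applies the population-variance identity Var(X) = E[X^2] - (E[X])^2. A scalar sketch of the same computation (equivalent to the QuantizerStddRef helper added to the test below; the names here are illustrative):

#include <cmath>
#include <cstddef>

struct MeanStdScalar { float mean; float stddev; };

// One pass over the data: accumulate the sum and the sum of squares, then
//   mean   = sum / n
//   stddev = sqrt(sum_sq / n - mean * mean)
MeanStdScalar GetQuantizerStdScalar(const float *begin, const float *end) {
  const std::size_t n = end - begin;
  float sum = 0.0f, sum_sq = 0.0f;
  for (const float *it = begin; it != end; ++it) {
    sum += *it;
    sum_sq += *it * *it;
  }
  MeanStdScalar ret;
  ret.mean = sum / n;
  ret.stddev = std::sqrt(sum_sq / n - ret.mean * ret.mean);
  return ret;
}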
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index fd3ab8c..df6144e 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -156,6 +156,8 @@ struct SSSE3_8bit {
INTGEMM_PREPAREBIASFOR8(__m128i, INTGEMM_SSSE3, CPUType::SSE2)
+ INTGEMM_GETQUANTIZERSTD(__m128, INTGEMM_SSSE3)
+
constexpr static const char *const kName = "8-bit SSSE3";
static const CPUType kUses = CPUType::SSSE3;
diff --git a/test/quantize_test.cc b/test/quantize_test.cc
index ee27261..7b5c6a3 100644
--- a/test/quantize_test.cc
+++ b/test/quantize_test.cc
@@ -30,6 +30,40 @@ void QuantizeRef(const float *input, int8_t *output, float quant_mult, std::size
}
}
+MeanStd QuantizerStddRef(AlignedVector<float>& vals, int num_items) {
+ float normal_sums = 0;
+ float squares_sum = 0;
+ std::for_each(vals.begin(), vals.end(), [&] (float n) {normal_sums+=n;});
+ std::for_each(vals.begin(), vals.end(), [&] (float n) {squares_sum+=n*n;});
+
+ MeanStd ret;
+ ret.mean = normal_sums/num_items;
+ ret.stddev = std::sqrt((squares_sum/num_items) - (ret.mean*ret.mean));
+ return ret;
+}
+
+template <class Backend>
+void testQuantizerStd(int num_items) {
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ AlignedVector<float> inputVec(num_items);
+
+ for (auto&& it : inputVec) {
+ it = dist(gen);
+ }
+
+ MeanStd reference = QuantizerStddRef(inputVec, num_items);
+ MeanStd fast = Backend::GetQuantizerStd(inputVec.begin(), inputVec.end());
+
+ float meanDifference = fabs(reference.mean - fast.mean);
+ float stdDifference = fabs(reference.stddev - fast.stddev);
+ float eps = 0.00002f; // Accumulating horizontal sums can introduce small errors.
+
+ CHECK_MESSAGE(meanDifference <= eps, "Reference mean: " << reference.mean << " actual: " << fast.mean);
+ CHECK_MESSAGE(stdDifference <= eps, "Reference stddev: " << reference.stddev << " actual: " << fast.stddev);
+}
+
template <class I> bool IsOff(float from, I ref, I test) {
if (ref == test) return false;
if (ref - test > 1 && test - ref > 1) return true;
@@ -94,5 +128,55 @@ TEST_CASE ("Quantize AVX2", "[quantize]") {
}
#endif
+TEST_CASE("QuantizeStd SSSE3", "[quantizerSTD]") {
+ if (kCPU < CPUType::SSSE3) return;
+ testQuantizerStd<SSSE3_8bit>(64);
+ testQuantizerStd<SSSE3_8bit>(64);
+ testQuantizerStd<SSSE3_8bit>(256);
+ testQuantizerStd<SSSE3_8bit>(256);
+ testQuantizerStd<SSSE3_8bit>(2048);
+ testQuantizerStd<SSSE3_8bit>(2048);
+ testQuantizerStd<SSSE3_8bit>(65536);
+ testQuantizerStd<SSSE3_8bit>(65536);
+ testQuantizerStd<SSSE3_8bit>(81920);
+ testQuantizerStd<SSSE3_8bit>(81920);
+ testQuantizerStd<SSSE3_8bit>(120832);
+ testQuantizerStd<SSSE3_8bit>(120832);
+}
+
+TEST_CASE("QuantizeStd AVX2", "[quantizerSTD]") {
+ if (kCPU < CPUType::AVX2) return;
+ testQuantizerStd<AVX2_8bit>(64);
+ testQuantizerStd<AVX2_8bit>(64);
+ testQuantizerStd<AVX2_8bit>(256);
+ testQuantizerStd<AVX2_8bit>(256);
+ testQuantizerStd<AVX2_8bit>(2048);
+ testQuantizerStd<AVX2_8bit>(2048);
+ testQuantizerStd<AVX2_8bit>(65536);
+ testQuantizerStd<AVX2_8bit>(65536);
+ testQuantizerStd<AVX2_8bit>(81920);
+ testQuantizerStd<AVX2_8bit>(81920);
+ testQuantizerStd<AVX2_8bit>(120832);
+ testQuantizerStd<AVX2_8bit>(120832);
+}
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+TEST_CASE("QuantizeStd AVX512", "[quantizerSTD]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ testQuantizerStd<AVX512_8bit>(64);
+ testQuantizerStd<AVX512_8bit>(64);
+ testQuantizerStd<AVX512_8bit>(256);
+ testQuantizerStd<AVX512_8bit>(256);
+ testQuantizerStd<AVX512_8bit>(2048);
+ testQuantizerStd<AVX512_8bit>(2048);
+ testQuantizerStd<AVX512_8bit>(65536);
+ testQuantizerStd<AVX512_8bit>(65536);
+ testQuantizerStd<AVX512_8bit>(81920);
+ testQuantizerStd<AVX512_8bit>(81920);
+ testQuantizerStd<AVX512_8bit>(120832);
+ testQuantizerStd<AVX512_8bit>(120832);
+}
+#endif
+
} // namespace
} // namespace intgemm