diff options
author | Kenneth Heafield <kpu@users.noreply.github.com> | 2020-03-25 20:54:37 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-25 20:54:37 +0300 |
commit | 14bbbdbe83e46d6da7df29f8c7a7961039a58ec5 (patch) | |
tree | 30c06fa7bf022f5ee9ad4346af42c34c16c449eb | |
parent | 65176b06d3caea37bd0d9d5154686f073f37ad6b (diff) | |
parent | d596aace3e80a223358672e7fbd178e83d2ab609 (diff) |
Merge pull request #54 from kpu/stdQuantizer
Add standard deviation quantizer
-rw-r--r-- | avx2_gemm.h | 2 | ||||
-rw-r--r-- | avx512_gemm.h | 2 | ||||
-rw-r--r-- | intgemm.cc | 2 | ||||
-rw-r--r-- | intgemm.h | 8 | ||||
-rw-r--r-- | multiply.h | 49 | ||||
-rw-r--r-- | ssse3_gemm.h | 2 | ||||
-rw-r--r-- | test/quantize_test.cc | 84 |
7 files changed, 149 insertions, 0 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h index 25866bb..529d628 100644 --- a/avx2_gemm.h +++ b/avx2_gemm.h @@ -240,6 +240,8 @@ struct AVX2_8bit { INTGEMM_PREPAREBIASFOR8(__m256i, INTGEMM_AVX2, CPUType::AVX2) + INTGEMM_GETQUANTIZERSTD(__m256, INTGEMM_AVX2) + constexpr static const char *const kName = "8-bit AVX2"; static const CPUType kUses = CPUType::AVX2; diff --git a/avx512_gemm.h b/avx512_gemm.h index 6286ccc..b8a561f 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -433,6 +433,8 @@ struct AVX512_8bit { INTGEMM_PREPAREBIASFOR8(__m512i, INTGEMM_AVX512BW, CPUType::AVX2) + INTGEMM_GETQUANTIZERSTD(__m512, INTGEMM_AVX512BW) + constexpr static const char *const kName = "8-bit AVX512BW"; static const CPUType kUses = CPUType::AVX512BW; @@ -40,6 +40,8 @@ const CPUType kCPU = ChooseCPU(CPUType::AVX512VNNI, CPUType::AVX512BW, CPUType:: float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(avx512f::MaxAbsolute, avx512f::MaxAbsolute, avx2::MaxAbsolute, sse2::MaxAbsolute, sse2::MaxAbsolute, Unsupported_MaxAbsolute); +MeanStd (*GetQuantizerStd)(const float *begin, const float *end) = ChooseCPU(AVX512VNNI_8bit::GetQuantizerStd, AVX512_8bit::GetQuantizerStd, AVX2_8bit::GetQuantizerStd, SSSE3_8bit::GetQuantizerStd, Unsupported_8bit::GetQuantizerStd, Unsupported_8bit::GetQuantizerStd); + constexpr const char *const Unsupported_16bit::kName; constexpr const char *const Unsupported_8bit::kName; constexpr const char *const SSE2_16bit::kName; @@ -116,6 +116,11 @@ struct Unsupported_8bit { static void Multiply8Shift(const uint8_t *, const int8_t *, Index, Index, Index, Callback) { throw UnsupportedCPU(); } + + static MeanStd GetQuantizerStd(const float *, const float *) { + throw UnsupportedCPU(); + } + constexpr static const char *const kName = "8-bit Unsupported"; }; @@ -271,6 +276,9 @@ struct Int8 { MultiplyImpl<Callback>::run(A, B, A_rows, width, B_cols, callback); } + // Get a Quantization value that is equal to the mean of the data +N standard deviations. 
Use 2 by default + static MeanStd (*GetQuantizerStd)(const float *begin, const float *end); + static const char *const kName; private: @@ -6,8 +6,15 @@ #include "vec_traits.h" #include "callbacks.h" +#include <cmath> //sqrt + namespace intgemm { +struct MeanStd { + float mean; + float stddev; +}; + INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) { // Fold to just using the first 64 bits. __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2); @@ -36,6 +43,22 @@ INTGEMM_AVX2 static inline __m256i PermuteSummer(__m256i pack0123, __m256i pack4 return _mm256_add_epi32(rev, blended); } +/* https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-float-vector-sum-on-x86 */ +INTGEMM_SSSE3 static inline float horizontalSum(__m128 a) { + __m128 shuf = _mm_movehdup_ps(a); // broadcast elements 3,1 to 2,0 + __m128 sums = _mm_add_ps(a, shuf); + shuf = _mm_movehl_ps(shuf, sums); // high half -> low half + sums = _mm_add_ss(sums, shuf); + return _mm_cvtss_f32(sums); +} + +INTGEMM_AVX2 static inline float horizontalSum(__m256 a) { + __m128 vlow = _mm256_castps256_ps128(a); + __m128 vhigh = _mm256_extractf128_ps(a, 1); // high 128 + vlow = _mm_add_ps(vlow, vhigh); // add the low 128 + return horizontalSum(vlow); // and inline the sse3 version, which is optimal for AVX +} + #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ INTGEMM_AVX512BW static inline __m256i PermuteSummer(__m512i pack0123, __m512i pack4567) { @@ -57,6 +80,10 @@ static inline INTGEMM_AVX512F float MaxFloat32(__m512 a) { return MaxFloat32(max_ps(_mm512_castps512_ps256(a), upper)); } +static inline INTGEMM_AVX512F float horizontalSum(__m512 a) { + return _mm512_reduce_add_ps(a); +} + #endif // Quantize function used for SSSE3 and AVX2. 
@@ -615,4 +642,26 @@ target static inline float MaxAbsolute(const float *begin_float, const float *en return ret; \ } \ +#define INTGEMM_GETQUANTIZERSTD(Register, target) \ +target static MeanStd GetQuantizerStd(const float *begin_float, const float *end_float) { \ + /* Finds a quantizer value that is a certain number of standard deviations of the mean */ \ + assert(end_float > begin_float); \ + assert((end_float - begin_float) % (sizeof(Register) / sizeof(float)) == 0); \ + size_t num_items = end_float - begin_float; \ + const Register *begin = reinterpret_cast<const Register*>(begin_float); \ + const Register *end = reinterpret_cast<const Register*>(end_float); \ + Register squares = set1_ps<Register>(0); \ + Register sums = set1_ps<Register>(0); \ + for (; begin != end; begin++) { \ + squares = add_ps(squares, mul_ps(*begin, *begin)); \ + sums = add_ps(sums, *begin); \ + } \ + float squares_sum = horizontalSum(squares); \ + float normal_sums = horizontalSum(sums); \ + MeanStd ret; \ + ret.mean = normal_sums/num_items; \ + ret.stddev = std::sqrt((squares_sum/num_items) - (ret.mean*ret.mean)); \ + return ret; \ +} \ + } // namespace intgemm diff --git a/ssse3_gemm.h b/ssse3_gemm.h index fd3ab8c..df6144e 100644 --- a/ssse3_gemm.h +++ b/ssse3_gemm.h @@ -156,6 +156,8 @@ struct SSSE3_8bit { INTGEMM_PREPAREBIASFOR8(__m128i, INTGEMM_SSSE3, CPUType::SSE2) + INTGEMM_GETQUANTIZERSTD(__m128, INTGEMM_SSSE3) + constexpr static const char *const kName = "8-bit SSSE3"; static const CPUType kUses = CPUType::SSSE3; diff --git a/test/quantize_test.cc b/test/quantize_test.cc index ee27261..7b5c6a3 100644 --- a/test/quantize_test.cc +++ b/test/quantize_test.cc @@ -30,6 +30,40 @@ void QuantizeRef(const float *input, int8_t *output, float quant_mult, std::size } } +MeanStd QuantizerStddRef(AlignedVector<float>& vals, int num_items) { + float normal_sums = 0; + float squares_sum = 0; + std::for_each(vals.begin(), vals.end(), [&] (float n) {normal_sums+=n;}); + 
std::for_each(vals.begin(), vals.end(), [&] (float n) {squares_sum+=n*n;}); + + MeanStd ret; + ret.mean = normal_sums/num_items; + ret.stddev = std::sqrt((squares_sum/num_items) - (ret.mean*ret.mean)); + return ret; +} + +template <class Backend> +void testQuantizerStd(int num_items) { + std::mt19937 gen; + std::uniform_real_distribution<float> dist(-1.0f, 1.0f); + AlignedVector<float> inputVec(num_items); + + for (auto&& it : inputVec) { + it = dist(gen); + } + + MeanStd reference = QuantizerStddRef(inputVec, num_items); + MeanStd fast = Backend::GetQuantizerStd(inputVec.begin(), inputVec.end()); + + float meanDifference = fabs(reference.mean - fast.mean); + float stdDifference = fabs(reference.stddev - fast.stddev); + float eps = 0.00002; //Accumulating horizontal sums can lead to errors. + + CHECK_MESSAGE(meanDifference <= eps, "Reference mean: " << reference.mean << " actual: " << fast.mean);// /*Backend::kName << */" Mismatch:\n" << "Reference: " << reference << " Fast: " << fast << std::endl); + CHECK_MESSAGE(stdDifference <= eps, "Reference stddev: " << reference.stddev << " actual: " << fast.stddev); + +} + template <class I> bool IsOff(float from, I ref, I test) { if (ref == test) return false; if (ref - test > 1 && test - ref > 1) return true; @@ -94,5 +128,55 @@ TEST_CASE ("Quantize AVX2", "[quantize]") { } #endif +TEST_CASE("QuantizeStd SSSE3", "[quantizerSTD]") { + if (kCPU < CPUType::SSSE3) return; + testQuantizerStd<SSSE3_8bit>(64); + testQuantizerStd<SSSE3_8bit>(64); + testQuantizerStd<SSSE3_8bit>(256); + testQuantizerStd<SSSE3_8bit>(256); + testQuantizerStd<SSSE3_8bit>(2048); + testQuantizerStd<SSSE3_8bit>(2048); + testQuantizerStd<SSSE3_8bit>(65536); + testQuantizerStd<SSSE3_8bit>(65536); + testQuantizerStd<SSSE3_8bit>(81920); + testQuantizerStd<SSSE3_8bit>(81920); + testQuantizerStd<SSSE3_8bit>(120832); + testQuantizerStd<SSSE3_8bit>(120832); +} + +TEST_CASE("QuantizeStd AVX2", "[quantizerSTD]") { + if (kCPU < CPUType::AVX2) return; + 
testQuantizerStd<AVX2_8bit>(64); + testQuantizerStd<AVX2_8bit>(64); + testQuantizerStd<AVX2_8bit>(256); + testQuantizerStd<AVX2_8bit>(256); + testQuantizerStd<AVX2_8bit>(2048); + testQuantizerStd<AVX2_8bit>(2048); + testQuantizerStd<AVX2_8bit>(65536); + testQuantizerStd<AVX2_8bit>(65536); + testQuantizerStd<AVX2_8bit>(81920); + testQuantizerStd<AVX2_8bit>(81920); + testQuantizerStd<AVX2_8bit>(120832); + testQuantizerStd<AVX2_8bit>(120832); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE("QuantizeStd AVX512", "[quantizerSTD]") { + if (kCPU < CPUType::AVX512BW) return; + testQuantizerStd<AVX512_8bit>(64); + testQuantizerStd<AVX512_8bit>(64); + testQuantizerStd<AVX512_8bit>(256); + testQuantizerStd<AVX512_8bit>(256); + testQuantizerStd<AVX512_8bit>(2048); + testQuantizerStd<AVX512_8bit>(2048); + testQuantizerStd<AVX512_8bit>(65536); + testQuantizerStd<AVX512_8bit>(65536); + testQuantizerStd<AVX512_8bit>(81920); + testQuantizerStd<AVX512_8bit>(81920); + testQuantizerStd<AVX512_8bit>(120832); + testQuantizerStd<AVX512_8bit>(120832); +} +#endif + } // namespace } // namespace intgemm |