diff options
-rw-r--r-- | avx2_gemm.h | 4 | ||||
-rw-r--r-- | avx512_gemm.h | 4 | ||||
-rw-r--r-- | intgemm.cc | 2 | ||||
-rw-r--r-- | intgemm.h | 13 | ||||
-rw-r--r-- | sse2_gemm.h | 2 | ||||
-rw-r--r-- | ssse3_gemm.h | 2 | ||||
-rw-r--r-- | test/quantize_test.cc | 76 |
7 files changed, 51 insertions, 52 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h index 529d628..c1c4616 100644 --- a/avx2_gemm.h +++ b/avx2_gemm.h @@ -192,6 +192,8 @@ class QuantizeTile8 { // Technically only requires AVX INTGEMM_MAXABSOLUTE(__m256, INTGEMM_AVX2) +INTGEMM_GETQUANTIZERSTD(__m256, INTGEMM_AVX2) + } // namespace struct AVX2_8bit { @@ -239,8 +241,6 @@ struct AVX2_8bit { INTGEMM_MULTIPLY8SHIFT(__m256i, INTGEMM_AVX2, CPUType::AVX2) INTGEMM_PREPAREBIASFOR8(__m256i, INTGEMM_AVX2, CPUType::AVX2) - - INTGEMM_GETQUANTIZERSTD(__m256, INTGEMM_AVX2) constexpr static const char *const kName = "8-bit AVX2"; diff --git a/avx512_gemm.h b/avx512_gemm.h index b8a561f..623e21a 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -159,6 +159,8 @@ class QuantizeTile8 { /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ INTGEMM_MAXABSOLUTE(__m512, INTGEMM_AVX512BW) +INTGEMM_GETQUANTIZERSTD(__m512, INTGEMM_AVX512BW) + } // namespace struct AVX512_16bit { @@ -433,8 +435,6 @@ struct AVX512_8bit { INTGEMM_PREPAREBIASFOR8(__m512i, INTGEMM_AVX512BW, CPUType::AVX2) - INTGEMM_GETQUANTIZERSTD(__m512, INTGEMM_AVX512BW) - constexpr static const char *const kName = "8-bit AVX512BW"; static const CPUType kUses = CPUType::AVX512BW; @@ -40,7 +40,7 @@ const CPUType kCPU = ChooseCPU(CPUType::AVX512VNNI, CPUType::AVX512BW, CPUType:: float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(avx512f::MaxAbsolute, avx512f::MaxAbsolute, avx2::MaxAbsolute, sse2::MaxAbsolute, sse2::MaxAbsolute, Unsupported_MaxAbsolute); -MeanStd (*GetQuantizerStd)(const float *begin, const float *end) = ChooseCPU(AVX512VNNI_8bit::GetQuantizerStd, AVX512_8bit::GetQuantizerStd, AVX2_8bit::GetQuantizerStd, SSSE3_8bit::GetQuantizerStd, Unsupported_8bit::GetQuantizerStd, Unsupported_8bit::GetQuantizerStd); +MeanStd (*GetQuantizerStd)(const float *begin, const float *end) = ChooseCPU(avx512f::GetQuantizerStd, avx512f::GetQuantizerStd, avx2::GetQuantizerStd, sse2::GetQuantizerStd, sse2::GetQuantizerStd, sse2::GetQuantizerStd); constexpr const char *const Unsupported_16bit::kName; constexpr const char *const Unsupported_8bit::kName; @@ -117,10 +117,6 @@ struct Unsupported_8bit { throw UnsupportedCPU(); } - static MeanStd GetQuantizerStd(const float *, const float *) { - throw UnsupportedCPU(); - } - constexpr static const char *const kName = "8-bit Unsupported"; }; @@ -132,6 +128,9 @@ namespace avx512f { static inline float MaxAbsolute(const float * /*begin*/, const float * /*end*/) { throw UnsupportedCPU(); } +static inline MeanStd MaxAbsolute(const float * /*begin*/, const float * /*end*/) { + throw UnsupportedCPU(); +} } //namespace #endif @@ -275,9 +274,6 @@ struct Int8 { static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { MultiplyImpl<Callback>::run(A, B, A_rows, width, B_cols, callback); } - - // Get a Quantization value that is equant to the mean of the data +N standard deviations. Use 2 by default - static MeanStd (*GetQuantizerStd)(const float *begin, const float *end); static const char *const kName; @@ -418,5 +414,8 @@ extern const CPUType kCPU; // Get the maximum absolute value of an array of floats. The number of floats must be a multiple of 16 and 64-byte aligned. extern float (*MaxAbsolute)(const float *begin, const float *end); +// Get a Quantization value that is equant to the mean of the data +N standard deviations. Use 2 by default +extern MeanStd (*GetQuantizerStd)(const float *begin, const float *end); + } // namespace intgemm diff --git a/sse2_gemm.h b/sse2_gemm.h index 34de052..91221d9 100644 --- a/sse2_gemm.h +++ b/sse2_gemm.h @@ -53,6 +53,8 @@ class QuantizeTile16 { INTGEMM_MAXABSOLUTE(__m128, INTGEMM_SSE2) +INTGEMM_GETQUANTIZERSTD(__m128, INTGEMM_SSE2) + } //namespace // This should be pure INTGEMM_SSE2 (and below). struct SSE2_16bit { diff --git a/ssse3_gemm.h b/ssse3_gemm.h index df6144e..fd3ab8c 100644 --- a/ssse3_gemm.h +++ b/ssse3_gemm.h @@ -156,8 +156,6 @@ struct SSSE3_8bit { INTGEMM_PREPAREBIASFOR8(__m128i, INTGEMM_SSSE3, CPUType::SSE2) - INTGEMM_GETQUANTIZERSTD(__m128, INTGEMM_SSSE3) - constexpr static const char *const kName = "8-bit SSSE3"; static const CPUType kUses = CPUType::SSSE3; diff --git a/test/quantize_test.cc b/test/quantize_test.cc index 7b5c6a3..4e3d424 100644 --- a/test/quantize_test.cc +++ b/test/quantize_test.cc @@ -42,7 +42,7 @@ MeanStd QuantizerStddRef(AlignedVector<float>& vals, int num_items) { return ret; } -template <class Backend> +template <MeanStd (*Backend) (const float *, const float *)> void testQuantizerStd(int num_items) { std::mt19937 gen; std::uniform_real_distribution<float> dist(-1.0f, 1.0f); @@ -53,7 +53,7 @@ void testQuantizerStd(int num_items) { } MeanStd reference = QuantizerStddRef(inputVec, num_items); - MeanStd fast = Backend::GetQuantizerStd(inputVec.begin(), inputVec.end()); + MeanStd fast = Backend(inputVec.begin(), inputVec.end()); float meanDifference = fabs(reference.mean - fast.mean); float stdDifference = fabs(reference.stddev - fast.stddev); @@ -130,51 +130,51 @@ TEST_CASE ("Quantize AVX2", "[quantize]") { TEST_CASE("QuantizeStd SSSE3", "[quantizerSTD]") { if (kCPU < CPUType::SSSE3) return; - testQuantizerStd<SSSE3_8bit>(64); - testQuantizerStd<SSSE3_8bit>(64); - testQuantizerStd<SSSE3_8bit>(256); - testQuantizerStd<SSSE3_8bit>(256); - testQuantizerStd<SSSE3_8bit>(2048); - testQuantizerStd<SSSE3_8bit>(2048); - testQuantizerStd<SSSE3_8bit>(65536); - testQuantizerStd<SSSE3_8bit>(65536); - testQuantizerStd<SSSE3_8bit>(81920); - testQuantizerStd<SSSE3_8bit>(81920); - testQuantizerStd<SSSE3_8bit>(120832); - testQuantizerStd<SSSE3_8bit>(120832); + testQuantizerStd<sse2::GetQuantizerStd>(64); + testQuantizerStd<sse2::GetQuantizerStd>(64); + testQuantizerStd<sse2::GetQuantizerStd>(256); + testQuantizerStd<sse2::GetQuantizerStd>(256); + testQuantizerStd<sse2::GetQuantizerStd>(2048); + testQuantizerStd<sse2::GetQuantizerStd>(2048); + testQuantizerStd<sse2::GetQuantizerStd>(65536); + testQuantizerStd<sse2::GetQuantizerStd>(65536); + testQuantizerStd<sse2::GetQuantizerStd>(81920); + testQuantizerStd<sse2::GetQuantizerStd>(81920); + testQuantizerStd<sse2::GetQuantizerStd>(120832); + testQuantizerStd<sse2::GetQuantizerStd>(120832); } TEST_CASE("QuantizeStd AVX2", "[quantizerSTD]") { if (kCPU < CPUType::AVX2) return; - testQuantizerStd<AVX2_8bit>(64); - testQuantizerStd<AVX2_8bit>(64); - testQuantizerStd<AVX2_8bit>(256); - testQuantizerStd<AVX2_8bit>(256); - testQuantizerStd<AVX2_8bit>(2048); - testQuantizerStd<AVX2_8bit>(2048); - testQuantizerStd<AVX2_8bit>(65536); - testQuantizerStd<AVX2_8bit>(65536); - testQuantizerStd<AVX2_8bit>(81920); - testQuantizerStd<AVX2_8bit>(81920); - testQuantizerStd<AVX2_8bit>(120832); - testQuantizerStd<AVX2_8bit>(120832); + testQuantizerStd<avx2::GetQuantizerStd>(64); + testQuantizerStd<avx2::GetQuantizerStd>(64); + testQuantizerStd<avx2::GetQuantizerStd>(256); + testQuantizerStd<avx2::GetQuantizerStd>(256); + testQuantizerStd<avx2::GetQuantizerStd>(2048); + testQuantizerStd<avx2::GetQuantizerStd>(2048); + testQuantizerStd<avx2::GetQuantizerStd>(65536); + testQuantizerStd<avx2::GetQuantizerStd>(65536); + testQuantizerStd<avx2::GetQuantizerStd>(81920); + testQuantizerStd<avx2::GetQuantizerStd>(81920); + testQuantizerStd<avx2::GetQuantizerStd>(120832); + testQuantizerStd<avx2::GetQuantizerStd>(120832); } #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW TEST_CASE("QuantizeStd AVX512", "[quantizerSTD]") { if (kCPU < CPUType::AVX512BW) return; - testQuantizerStd<AVX512_8bit>(64); - testQuantizerStd<AVX512_8bit>(64); - testQuantizerStd<AVX512_8bit>(256); - testQuantizerStd<AVX512_8bit>(256); - testQuantizerStd<AVX512_8bit>(2048); - testQuantizerStd<AVX512_8bit>(2048); - testQuantizerStd<AVX512_8bit>(65536); - testQuantizerStd<AVX512_8bit>(65536); - testQuantizerStd<AVX512_8bit>(81920); - testQuantizerStd<AVX512_8bit>(81920); - testQuantizerStd<AVX512_8bit>(120832); - testQuantizerStd<AVX512_8bit>(120832); + testQuantizerStd<avx512f::GetQuantizerStd>(64); + testQuantizerStd<avx512f::GetQuantizerStd>(64); + testQuantizerStd<avx512f::GetQuantizerStd>(256); + testQuantizerStd<avx512f::GetQuantizerStd>(256); + testQuantizerStd<avx512f::GetQuantizerStd>(2048); + testQuantizerStd<avx512f::GetQuantizerStd>(2048); + testQuantizerStd<avx512f::GetQuantizerStd>(65536); + testQuantizerStd<avx512f::GetQuantizerStd>(65536); + testQuantizerStd<avx512f::GetQuantizerStd>(81920); + testQuantizerStd<avx512f::GetQuantizerStd>(81920); + testQuantizerStd<avx512f::GetQuantizerStd>(120832); + testQuantizerStd<avx512f::GetQuantizerStd>(120832); } #endif |