Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNikolay Bogoychev <nheart@gmail.com>2020-03-26 13:08:57 +0300
committerNikolay Bogoychev <nheart@gmail.com>2020-03-26 13:08:57 +0300
commit6917f9d8f31c39e10d8c877c8307b1aee548ba0f (patch)
treea81a644fb8f1f8db024864b8a5d00863d4c633df
parent14bbbdbe83e46d6da7df29f8c7a7961039a58ec5 (diff)
Move QuantizerStd outside of the 8bit
-rw-r--r--avx2_gemm.h4
-rw-r--r--avx512_gemm.h4
-rw-r--r--intgemm.cc2
-rw-r--r--intgemm.h13
-rw-r--r--sse2_gemm.h2
-rw-r--r--ssse3_gemm.h2
-rw-r--r--test/quantize_test.cc76
7 files changed, 51 insertions, 52 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 529d628..c1c4616 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -192,6 +192,8 @@ class QuantizeTile8 {
// Technically only requires AVX
INTGEMM_MAXABSOLUTE(__m256, INTGEMM_AVX2)
+INTGEMM_GETQUANTIZERSTD(__m256, INTGEMM_AVX2)
+
} // namespace
struct AVX2_8bit {
@@ -239,8 +241,6 @@ struct AVX2_8bit {
INTGEMM_MULTIPLY8SHIFT(__m256i, INTGEMM_AVX2, CPUType::AVX2)
INTGEMM_PREPAREBIASFOR8(__m256i, INTGEMM_AVX2, CPUType::AVX2)
-
- INTGEMM_GETQUANTIZERSTD(__m256, INTGEMM_AVX2)
constexpr static const char *const kName = "8-bit AVX2";
diff --git a/avx512_gemm.h b/avx512_gemm.h
index b8a561f..623e21a 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -159,6 +159,8 @@ class QuantizeTile8 {
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_MAXABSOLUTE(__m512, INTGEMM_AVX512BW)
+INTGEMM_GETQUANTIZERSTD(__m512, INTGEMM_AVX512BW)
+
} // namespace
struct AVX512_16bit {
@@ -433,8 +435,6 @@ struct AVX512_8bit {
INTGEMM_PREPAREBIASFOR8(__m512i, INTGEMM_AVX512BW, CPUType::AVX2)
- INTGEMM_GETQUANTIZERSTD(__m512, INTGEMM_AVX512BW)
-
constexpr static const char *const kName = "8-bit AVX512BW";
static const CPUType kUses = CPUType::AVX512BW;
diff --git a/intgemm.cc b/intgemm.cc
index 5014b47..095b38b 100644
--- a/intgemm.cc
+++ b/intgemm.cc
@@ -40,7 +40,7 @@ const CPUType kCPU = ChooseCPU(CPUType::AVX512VNNI, CPUType::AVX512BW, CPUType::
float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(avx512f::MaxAbsolute, avx512f::MaxAbsolute, avx2::MaxAbsolute, sse2::MaxAbsolute, sse2::MaxAbsolute, Unsupported_MaxAbsolute);
-MeanStd (*GetQuantizerStd)(const float *begin, const float *end) = ChooseCPU(AVX512VNNI_8bit::GetQuantizerStd, AVX512_8bit::GetQuantizerStd, AVX2_8bit::GetQuantizerStd, SSSE3_8bit::GetQuantizerStd, Unsupported_8bit::GetQuantizerStd, Unsupported_8bit::GetQuantizerStd);
+MeanStd (*GetQuantizerStd)(const float *begin, const float *end) = ChooseCPU(avx512f::GetQuantizerStd, avx512f::GetQuantizerStd, avx2::GetQuantizerStd, sse2::GetQuantizerStd, sse2::GetQuantizerStd, sse2::GetQuantizerStd);
constexpr const char *const Unsupported_16bit::kName;
constexpr const char *const Unsupported_8bit::kName;
diff --git a/intgemm.h b/intgemm.h
index ac877e8..8c5309b 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -117,10 +117,6 @@ struct Unsupported_8bit {
throw UnsupportedCPU();
}
- static MeanStd GetQuantizerStd(const float *, const float *) {
- throw UnsupportedCPU();
- }
-
constexpr static const char *const kName = "8-bit Unsupported";
};
@@ -132,6 +128,9 @@ namespace avx512f {
static inline float MaxAbsolute(const float * /*begin*/, const float * /*end*/) {
throw UnsupportedCPU();
}
+static inline MeanStd GetQuantizerStd(const float * /*begin*/, const float * /*end*/) {
+ throw UnsupportedCPU(); // stub: ignores its arguments and always throws, like MaxAbsolute above
+}
} //namespace
#endif
@@ -275,9 +274,6 @@ struct Int8 {
static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
MultiplyImpl<Callback>::run(A, B, A_rows, width, B_cols, callback);
}
-
- // Get a Quantization value that is equant to the mean of the data +N standard deviations. Use 2 by default
- static MeanStd (*GetQuantizerStd)(const float *begin, const float *end);
static const char *const kName;
@@ -418,5 +414,8 @@ extern const CPUType kCPU;
// Get the maximum absolute value of an array of floats. The number of floats must be a multiple of 16 and 64-byte aligned.
extern float (*MaxAbsolute)(const float *begin, const float *end);
+// Get a Quantization value that is equal to the mean of the data +N standard deviations. Use 2 by default
+extern MeanStd (*GetQuantizerStd)(const float *begin, const float *end);
+
} // namespace intgemm
diff --git a/sse2_gemm.h b/sse2_gemm.h
index 34de052..91221d9 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -53,6 +53,8 @@ class QuantizeTile16 {
INTGEMM_MAXABSOLUTE(__m128, INTGEMM_SSE2)
+INTGEMM_GETQUANTIZERSTD(__m128, INTGEMM_SSE2)
+
} //namespace
// This should be pure INTGEMM_SSE2 (and below).
struct SSE2_16bit {
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index df6144e..fd3ab8c 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -156,8 +156,6 @@ struct SSSE3_8bit {
INTGEMM_PREPAREBIASFOR8(__m128i, INTGEMM_SSSE3, CPUType::SSE2)
- INTGEMM_GETQUANTIZERSTD(__m128, INTGEMM_SSSE3)
-
constexpr static const char *const kName = "8-bit SSSE3";
static const CPUType kUses = CPUType::SSSE3;
diff --git a/test/quantize_test.cc b/test/quantize_test.cc
index 7b5c6a3..4e3d424 100644
--- a/test/quantize_test.cc
+++ b/test/quantize_test.cc
@@ -42,7 +42,7 @@ MeanStd QuantizerStddRef(AlignedVector<float>& vals, int num_items) {
return ret;
}
-template <class Backend>
+template <MeanStd (*Backend) (const float *, const float *)>
void testQuantizerStd(int num_items) {
std::mt19937 gen;
std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
@@ -53,7 +53,7 @@ void testQuantizerStd(int num_items) {
}
MeanStd reference = QuantizerStddRef(inputVec, num_items);
- MeanStd fast = Backend::GetQuantizerStd(inputVec.begin(), inputVec.end());
+ MeanStd fast = Backend(inputVec.begin(), inputVec.end());
float meanDifference = fabs(reference.mean - fast.mean);
float stdDifference = fabs(reference.stddev - fast.stddev);
@@ -130,51 +130,51 @@ TEST_CASE ("Quantize AVX2", "[quantize]") {
TEST_CASE("QuantizeStd SSSE3", "[quantizerSTD]") {
if (kCPU < CPUType::SSSE3) return;
- testQuantizerStd<SSSE3_8bit>(64);
- testQuantizerStd<SSSE3_8bit>(64);
- testQuantizerStd<SSSE3_8bit>(256);
- testQuantizerStd<SSSE3_8bit>(256);
- testQuantizerStd<SSSE3_8bit>(2048);
- testQuantizerStd<SSSE3_8bit>(2048);
- testQuantizerStd<SSSE3_8bit>(65536);
- testQuantizerStd<SSSE3_8bit>(65536);
- testQuantizerStd<SSSE3_8bit>(81920);
- testQuantizerStd<SSSE3_8bit>(81920);
- testQuantizerStd<SSSE3_8bit>(120832);
- testQuantizerStd<SSSE3_8bit>(120832);
+ testQuantizerStd<sse2::GetQuantizerStd>(64);
+ testQuantizerStd<sse2::GetQuantizerStd>(64);
+ testQuantizerStd<sse2::GetQuantizerStd>(256);
+ testQuantizerStd<sse2::GetQuantizerStd>(256);
+ testQuantizerStd<sse2::GetQuantizerStd>(2048);
+ testQuantizerStd<sse2::GetQuantizerStd>(2048);
+ testQuantizerStd<sse2::GetQuantizerStd>(65536);
+ testQuantizerStd<sse2::GetQuantizerStd>(65536);
+ testQuantizerStd<sse2::GetQuantizerStd>(81920);
+ testQuantizerStd<sse2::GetQuantizerStd>(81920);
+ testQuantizerStd<sse2::GetQuantizerStd>(120832);
+ testQuantizerStd<sse2::GetQuantizerStd>(120832);
}
TEST_CASE("QuantizeStd AVX2", "[quantizerSTD]") {
if (kCPU < CPUType::AVX2) return;
- testQuantizerStd<AVX2_8bit>(64);
- testQuantizerStd<AVX2_8bit>(64);
- testQuantizerStd<AVX2_8bit>(256);
- testQuantizerStd<AVX2_8bit>(256);
- testQuantizerStd<AVX2_8bit>(2048);
- testQuantizerStd<AVX2_8bit>(2048);
- testQuantizerStd<AVX2_8bit>(65536);
- testQuantizerStd<AVX2_8bit>(65536);
- testQuantizerStd<AVX2_8bit>(81920);
- testQuantizerStd<AVX2_8bit>(81920);
- testQuantizerStd<AVX2_8bit>(120832);
- testQuantizerStd<AVX2_8bit>(120832);
+ testQuantizerStd<avx2::GetQuantizerStd>(64);
+ testQuantizerStd<avx2::GetQuantizerStd>(64);
+ testQuantizerStd<avx2::GetQuantizerStd>(256);
+ testQuantizerStd<avx2::GetQuantizerStd>(256);
+ testQuantizerStd<avx2::GetQuantizerStd>(2048);
+ testQuantizerStd<avx2::GetQuantizerStd>(2048);
+ testQuantizerStd<avx2::GetQuantizerStd>(65536);
+ testQuantizerStd<avx2::GetQuantizerStd>(65536);
+ testQuantizerStd<avx2::GetQuantizerStd>(81920);
+ testQuantizerStd<avx2::GetQuantizerStd>(81920);
+ testQuantizerStd<avx2::GetQuantizerStd>(120832);
+ testQuantizerStd<avx2::GetQuantizerStd>(120832);
}
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("QuantizeStd AVX512", "[quantizerSTD]") {
if (kCPU < CPUType::AVX512BW) return;
- testQuantizerStd<AVX512_8bit>(64);
- testQuantizerStd<AVX512_8bit>(64);
- testQuantizerStd<AVX512_8bit>(256);
- testQuantizerStd<AVX512_8bit>(256);
- testQuantizerStd<AVX512_8bit>(2048);
- testQuantizerStd<AVX512_8bit>(2048);
- testQuantizerStd<AVX512_8bit>(65536);
- testQuantizerStd<AVX512_8bit>(65536);
- testQuantizerStd<AVX512_8bit>(81920);
- testQuantizerStd<AVX512_8bit>(81920);
- testQuantizerStd<AVX512_8bit>(120832);
- testQuantizerStd<AVX512_8bit>(120832);
+ testQuantizerStd<avx512f::GetQuantizerStd>(64);
+ testQuantizerStd<avx512f::GetQuantizerStd>(64);
+ testQuantizerStd<avx512f::GetQuantizerStd>(256);
+ testQuantizerStd<avx512f::GetQuantizerStd>(256);
+ testQuantizerStd<avx512f::GetQuantizerStd>(2048);
+ testQuantizerStd<avx512f::GetQuantizerStd>(2048);
+ testQuantizerStd<avx512f::GetQuantizerStd>(65536);
+ testQuantizerStd<avx512f::GetQuantizerStd>(65536);
+ testQuantizerStd<avx512f::GetQuantizerStd>(81920);
+ testQuantizerStd<avx512f::GetQuantizerStd>(81920);
+ testQuantizerStd<avx512f::GetQuantizerStd>(120832);
+ testQuantizerStd<avx512f::GetQuantizerStd>(120832);
}
#endif