diff options
-rw-r--r-- | avx2_gemm.h | 9 | ||||
-rw-r--r-- | avx512_gemm.h | 17 | ||||
-rw-r--r-- | benchmark.cc | 22 | ||||
-rw-r--r-- | example.cc | 4 | ||||
-rw-r--r-- | intgemm.h | 60 | ||||
-rw-r--r-- | multiply.h | 19 | ||||
-rw-r--r-- | sse2_gemm.h | 5 | ||||
-rw-r--r-- | ssse3_gemm.h | 5 | ||||
-rw-r--r-- | test/multiply_test.cc | 76 |
9 files changed, 109 insertions, 108 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h index c560b49..d278e3e 100644 --- a/avx2_gemm.h +++ b/avx2_gemm.h @@ -82,10 +82,12 @@ struct AVX2_16bit { AVX2 static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end); } - +/* AVX2 static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { Multiply16__m256i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols); } + */ + MULTIPLY16_define(__m256i, AVX2) constexpr static const char *const kName = "16-bit AVX2"; @@ -174,11 +176,12 @@ struct AVX2_8bit { AVX2 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end); } - +/* AVX2 static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { //Multiply8_SSE2OrAVX2<Multiply8_AVXAVX2, __m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols); Multiply8_SSE2OrAVX2__m256i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols); - } + }*/ + MULTIPLY8_define(__m256i, AVX2) constexpr static const char *const kName = "8-bit AVX2"; diff --git a/avx512_gemm.h b/avx512_gemm.h index 93a5997..97d2e73 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -155,11 +155,12 @@ struct AVX512_16bit { AVX512F static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { avx512f::SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end); } - +/* AVX512F static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { // The unquantization is only 256-bit wide because there are 8 results. Multiply16__m512i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols); - } + }*/ + MULTIPLY16_define(__m512i, AVX512F) constexpr static const char *const kName = "16-bit AVX512"; @@ -206,16 +207,17 @@ struct AVX512_8bit { // Special AVX512 implementation due to having 32 registers (so I don't have to // allocate registers manually) and no sign instruction. - AVX512BW static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { + template <class WriteC> + AVX512BW static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { typedef __m512i Integer; - typedef __m256 Float; // For quantization we only do 8 at a time. + //typedef __m256 Float; // For quantization we only do 8 at a time. // This is copy-paste from Multiply8_SSE2OrAVX2. assert(width % sizeof(Integer) == 0); assert(B_cols % 8 == 0); assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); - assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0); - Float unquant_reg = set1_ps<Float>(unquant_mult); + //assert(reinterpret_cast<uintptr_t>(C) % sizeof(Integer) == 0); + //Float unquant_reg = set1_ps<Float>(unquant_mult); const int simd_width = width / sizeof(Integer); const Integer *B0_col = reinterpret_cast<const Integer*>(B); // Added for AVX512. @@ -312,7 +314,8 @@ struct AVX512_8bit { Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); - WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg); + //WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg); + functor(A_rowidx, B_cols, B0_colidx, total); } } } diff --git a/benchmark.cc b/benchmark.cc index a629049..965fd68 100644 --- a/benchmark.cc +++ b/benchmark.cc @@ -62,7 +62,7 @@ struct RandomMatrices { AlignedVector<float> A, B; }; -template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t> &stats) { +template <class Backend, class WriteC> void Run(const RandomMatrices &m, std::vector<uint64_t> &stats) { typedef typename Backend::Integer Integer; float quant_mult = 127.0 / 2; float unquant_mult = 1.0 / (quant_mult * quant_mult); @@ -72,20 +72,20 @@ template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t> Backend::PrepareB(m.B.get(), B_prepared.get(), quant_mult, m.width, m.B_cols); AlignedVector<float> output(m.A_rows * m.B_cols); // Burn in - Backend::Multiply(A_prepared.get(), B_prepared.get(), output.get(), unquant_mult, m.A_rows, m.width, m.B_cols); + Backend::template Multiply<WriteC>(A_prepared.get(), B_prepared.get(), JustUnquantizeC(output.get(), unquant_mult), m.A_rows, m.width, m.B_cols); { StopWatch w(stats); - Backend::Multiply(A_prepared.get(), B_prepared.get(), output.get(), unquant_mult, m.A_rows, m.width, m.B_cols); + Backend::template Multiply<WriteC>(A_prepared.get(), B_prepared.get(), JustUnquantizeC(output.get(), unquant_mult), m.A_rows, m.width, m.B_cols); } } -template <class Backend> void RunAll(RandomMatrices *matrices, RandomMatrices *matrices_end, std::vector<std::vector<uint64_t> > &stats) { +template <class Backend, class WriteC> void RunAll(RandomMatrices *matrices, RandomMatrices *matrices_end, std::vector<std::vector<uint64_t> > &stats) { if (Backend::kUses > kCPU) return; std::size_t size = matrices_end - matrices; if (stats.size() < size) stats.resize(size); for (std::size_t i = 0; i < size; ++i) { - Run<Backend>(matrices[i], stats[i]); + Run<Backend, WriteC>(matrices[i], stats[i]); } } @@ -169,13 +169,13 @@ int main(int argc, char ** argv) { for (int samples = 0; samples < kSamples; ++samples) { std::cerr << "Sample " << samples << " / " << kSamples << std::endl; RandomMatrices *end = (samples < 4) ? matrices_end : full_sample; - RunAll<SSSE3_8bit>(matrices, end, stats.ssse3_8bit); - RunAll<SSE2_16bit>(matrices, end, stats.sse2_16bit); - RunAll<AVX2_8bit>(matrices, end, stats.avx2_8bit); - RunAll<AVX2_16bit>(matrices, end, stats.avx2_16bit); + RunAll<SSSE3_8bit, JustUnquantizeC>(matrices, end, stats.ssse3_8bit); + RunAll<SSE2_16bit, JustUnquantizeC>(matrices, end, stats.sse2_16bit); + RunAll<AVX2_8bit, JustUnquantizeC>(matrices, end, stats.avx2_8bit); + RunAll<AVX2_16bit, JustUnquantizeC>(matrices, end, stats.avx2_16bit); #ifndef INTGEMM_NO_AVX512 - RunAll<AVX512_8bit>(matrices, end, stats.avx512_8bit); - RunAll<AVX512_16bit>(matrices, end, stats.avx512_16bit); + RunAll<AVX512_8bit, JustUnquantizeC>(matrices, end, stats.avx512_8bit); + RunAll<AVX512_16bit, JustUnquantizeC>(matrices, end, stats.avx512_16bit); #endif } @@ -49,7 +49,7 @@ int main() { AlignedVector<float> C(A_rows * B_cols); // Do the actual multiply. - intgemm::Int16<intgemm::JustUnquantizeC>::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols); + intgemm::Int16<intgemm::JustUnquantizeC>::Multiply(A_prepared.get(), B_prepared.get(), intgemm::JustUnquantizeC(C.get(), 1.0 / (quant_mult * quant_mult)), A_rows, width, B_cols); // Sanity check. C will be row major. assert(fabs(C[0] - top_left_reference) < 0.05); } @@ -68,7 +68,7 @@ int main() { AlignedVector<float> C(A_rows * B_cols); // Do the actual multiply. - intgemm::Int8<intgemm::JustUnquantizeC>::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols); + intgemm::Int8<intgemm::JustUnquantizeC>::Multiply(A_prepared.get(), B_prepared.get(), intgemm::JustUnquantizeC(C.get(), 1.0 / (quant_mult * quant_mult)), A_rows, width, B_cols); // Sanity check. C will be row major. assert(fabs(C[0] - top_left_reference) < 0.05); } @@ -56,7 +56,9 @@ namespace intgemm { -struct Unsupported_16bit { +template<class WriteC> +class Unsupported_16bit { +public: static void Quantize(const float *, int16_t *, float, Index) { throw UnsupportedCPU(); } @@ -66,13 +68,15 @@ struct Unsupported_16bit { static void SelectColumnsB(const int16_t *, int16_t *, Index, const Index *, const Index *) { throw UnsupportedCPU(); } - static void Multiply(const int16_t *, const int16_t *, float *, float, Index, Index, Index) { + static void Multiply(const int16_t *, const int16_t *, WriteC, Index, Index, Index) { throw UnsupportedCPU(); } constexpr static const char *const kName = "16-bit Unsupported"; }; -struct Unsupported_8bit { +template<class WriteC> +class Unsupported_8bit { +public: static void Quantize(const float *, int8_t *, float, Index) { throw UnsupportedCPU(); } @@ -82,7 +86,7 @@ struct Unsupported_8bit { static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) { throw UnsupportedCPU(); } - static void Multiply(const int8_t *, const int8_t *, float *, float, Index, Index, Index) { + static void Multiply(const int8_t *, const int8_t *, WriteC, Index, Index, Index) { throw UnsupportedCPU(); } constexpr static const char *const kName = "8-bit Unsupported"; @@ -132,7 +136,7 @@ template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported) } /* 16-bit matrix multiplication. */ -template<class cOperator> +template<class WriteC> class Int16 { public: typedef int16_t Integer; @@ -163,28 +167,28 @@ public: static void (*SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end); // Multiply C = A * B, presuming A and B have been prepared. - static void (*Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols); + static void (*Multiply)(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols); static const char *const kName; }; -template <class cOperator> -void (*Int16<cOperator>::Quantize)(const float *input, int16_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit::Quantize); +template <class WriteC> +void (*Int16<WriteC>::Quantize)(const float *input, int16_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit<WriteC>::Quantize); -template <class cOperator> -void (*Int16<cOperator>::PrepareB)(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit::PrepareB); +template <class WriteC> +void (*Int16<WriteC>::PrepareB)(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit<WriteC>::PrepareB); -template <class cOperator> -void (*Int16<cOperator>::SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_16bit::SelectColumnsB, AVX2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, Unsupported_16bit::SelectColumnsB); +template <class WriteC> +void (*Int16<WriteC>::SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_16bit::SelectColumnsB, AVX2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, Unsupported_16bit<WriteC>::SelectColumnsB); -template <class cOperator> -void (*Int16<cOperator>::Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply, AVX2_16bit::Multiply, SSE2_16bit::Multiply, SSE2_16bit::Multiply, Unsupported_16bit::Multiply); +template <class WriteC> +void (*Int16<WriteC>::Multiply)(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply<WriteC>, AVX2_16bit::Multiply<WriteC>, SSE2_16bit::Multiply<WriteC>, SSE2_16bit::Multiply<WriteC>, Unsupported_16bit<WriteC>::Multiply); -template <class cOperator> -const char *const Int16<cOperator>::kName = ChooseCPU(AVX512_16bit::kName, AVX2_16bit::kName, SSE2_16bit::kName, SSE2_16bit::kName, Unsupported_16bit::kName); +template <class WriteC> +const char *const Int16<WriteC>::kName = ChooseCPU(AVX512_16bit::kName, AVX2_16bit::kName, SSE2_16bit::kName, SSE2_16bit::kName, Unsupported_16bit<WriteC>::kName); /* 8-bit matrix multiplication */ -template<class cOperator> +template<class WriteC> class Int8 { public: typedef int8_t Integer; @@ -214,25 +218,25 @@ public: static void (*SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end); // Multiply C = A * B, presuming A and B have been prepared. - static void (*Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols); + static void (*Multiply)(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols); static const char *const kName; }; -template<class cOperator> -void (*Int8<cOperator>::Quantize)(const float *input, int8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize); +template<class WriteC> +void (*Int8<WriteC>::Quantize)(const float *input, int8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit<WriteC>::Quantize, Unsupported_8bit<WriteC>::Quantize); -template<class cOperator> -void (*Int8<cOperator>::PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB); +template<class WriteC> +void (*Int8<WriteC>::PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit<WriteC>::PrepareB, Unsupported_8bit<WriteC>::PrepareB); -template<class cOperator> -void (*Int8<cOperator>::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_8bit::SelectColumnsB, AVX2_8bit::SelectColumnsB, SSSE3_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB); +template<class WriteC> +void (*Int8<WriteC>::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_8bit::SelectColumnsB, AVX2_8bit::SelectColumnsB, SSSE3_8bit::SelectColumnsB, Unsupported_8bit<WriteC>::SelectColumnsB, Unsupported_8bit<WriteC>::SelectColumnsB); -template<class cOperator> -void (*Int8<cOperator>::Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply, AVX2_8bit::Multiply, SSSE3_8bit::Multiply, Unsupported_8bit::Multiply, Unsupported_8bit::Multiply); +template<class WriteC> +void (*Int8<WriteC>::Multiply)(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply<WriteC>, AVX2_8bit::Multiply<WriteC>, SSSE3_8bit::Multiply<WriteC>, Unsupported_8bit<WriteC>::Multiply, Unsupported_8bit<WriteC>::Multiply); -template<class cOperator> -const char *const Int8<cOperator>::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit::kName, Unsupported_8bit::kName); +template<class WriteC> +const char *const Int8<WriteC>::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit<WriteC>::kName, Unsupported_8bit<WriteC>::kName); const CPUType kCPU = ChooseCPU(CPU_AVX512BW, CPU_AVX2, CPU_SSSE3, CPU_SSE2, CPU_UNSUPPORTED); @@ -134,9 +134,9 @@ PACK_DEFINE(AVX512F, __m512i) // A_rows can be anything non-negative. // width must be a multiple of the register size. // B_cols must be a multiple of 8. -//template <class Integer, class WriteC> inline void Multiply16(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { +// Multiply16 #define MULTIPLY16_define(Integer, target) \ - template <class WriteC> target inline void Multiply16##Integer(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \ + template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \ assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \ assert(B_cols % 8 == 0); \ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \ @@ -192,13 +192,6 @@ PACK_DEFINE(AVX512F, __m512i) } \ } \ -MULTIPLY16_define(__m128i, SSE2) - -MULTIPLY16_define(__m256i, AVX2) -#ifndef INTGEMM_NO_AVX512 -MULTIPLY16_define(__m512i, AVX512F) -#endif -//MULTIPLY16_define(__m256i, AVX2) /* 8-bit matrix multiply used by AVX and AVX2. * These have two peculiar properties: * 1. The sign instructions don't exist in AVX512. @@ -346,9 +339,9 @@ SSSE3 inline static void InnerSSSE3( sum6 = adds_epi16(sum6, maddubs_epi16(a_positive, sign_epi8(b[6], a))); sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a))); } - +//AVX2 or SSSE3 multiply #define MULTIPLY8_define(Integer, target) \ -template <class WriteC> target inline void Multiply8_SSE2OrAVX2##Integer(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \ +template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \ assert(width % sizeof(Integer) == 0); \ assert(B_cols % 8 == 0); \ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \ @@ -415,10 +408,6 @@ template <class WriteC> target inline void Multiply8_SSE2OrAVX2##Integer(const i } \ } \ -MULTIPLY8_define(__m128i, SSSE3) - -MULTIPLY8_define(__m256i, AVX2) - // Find the maximum absolute value of packed float32s. /* diff --git a/sse2_gemm.h b/sse2_gemm.h index a7d60ae..3ede6d8 100644 --- a/sse2_gemm.h +++ b/sse2_gemm.h @@ -73,11 +73,12 @@ struct SSE2_16bit { //TODO #DEFINE sse2::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end); } - +/* SSE2 static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { //TODO #DEFINE Multiply16__m128i<JustUnquantizeC> (A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols); - } + }*/ + MULTIPLY16_define(__m128i, SSE2) constexpr static const char *const kName = "16-bit SSE2"; diff --git a/ssse3_gemm.h b/ssse3_gemm.h index bb2df94..d384500 100644 --- a/ssse3_gemm.h +++ b/ssse3_gemm.h @@ -97,11 +97,12 @@ struct SSSE3_8bit { SSSE3 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) { ssse3::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end); } - +/* SSSE3 static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) { //Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols); Multiply8_SSE2OrAVX2__m128i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols); - } + }*/ + MULTIPLY8_define(__m128i, SSSE3) constexpr static const char *const kName = "8-bit SSSE3"; diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 26ca9be..ee5e2c6 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -331,7 +331,7 @@ void Compare(const float *float_ref, const float *int_ref, const float *int_test CHECK_MESSAGE(fabs(sqrt(int_sum / size)) <= MSE_int_tolerance, test_info << "Int MSE = " << sqrt(int_sum / size)); } -template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_cols, +template <class Routine, class WriteC> void TestMultiply(Index A_rows, Index width, Index B_cols, float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) { typedef typename Routine::Integer Integer; std::ostringstream info; @@ -357,7 +357,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co Routine::PrepareB(B.get(), B_prep.get(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); - Routine::Multiply(A_prep.get(), B_prep.get(), test_C.get(), unquant_mult, A_rows, width, B_cols); + Routine::template Multiply<WriteC>(A_prep.get(), B_prep.get(), WriteC(test_C.get(), unquant_mult), A_rows, width, B_cols); AlignedVector<Integer> B_quant(width * B_cols); Routine::Quantize(B.get(), B_quant.get(), quant_mult, width * B_cols); @@ -374,63 +374,63 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co TEST_CASE ("Multiply SSE2 16bit", "[multiply]") { if (kCPU < CPU_SSE2) return; - TestMultiply<SSE2_16bit>(8, 256, 256, .1, 1, 0.01); - TestMultiply<SSE2_16bit>(8, 2048, 256, .1, 1, 0.02); - TestMultiply<SSE2_16bit>(320, 256, 256, .1, 1, 0.01); - TestMultiply<SSE2_16bit>(472, 256, 256, .1, 1, 0.01); - TestMultiply<SSE2_16bit>(248, 256, 256, .1, 1, 0.01); - TestMultiply<SSE2_16bit>(200, 256, 256, .1, 1, 0.01); + TestMultiply<SSE2_16bit, JustUnquantizeC>(8, 256, 256, .1, 1, 0.01); + TestMultiply<SSE2_16bit, JustUnquantizeC>(8, 2048, 256, .1, 1, 0.02); + TestMultiply<SSE2_16bit, JustUnquantizeC>(320, 256, 256, .1, 1, 0.01); + TestMultiply<SSE2_16bit, JustUnquantizeC>(472, 256, 256, .1, 1, 0.01); + TestMultiply<SSE2_16bit, JustUnquantizeC>(248, 256, 256, .1, 1, 0.01); + TestMultiply<SSE2_16bit, JustUnquantizeC>(200, 256, 256, .1, 1, 0.01); } TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") { if (kCPU < CPU_SSSE3) return; - TestMultiply<SSSE3_8bit>(8, 256, 256, 1.2, 1.2, 0.064, 0.026); - TestMultiply<SSSE3_8bit>(8, 2048, 256, 33, 33, 4.4, 4.4); - TestMultiply<SSSE3_8bit>(320, 256, 256, 1.9, 1.9, 0.1, 0.01); - TestMultiply<SSSE3_8bit>(472, 256, 256, 2.1, 2.1, 0.1, 0.011); - TestMultiply<SSSE3_8bit>(248, 256, 256, 1.7, 1.7, 0.1, 0.012); - TestMultiply<SSSE3_8bit>(200, 256, 256, 1.8, 1.9, 0.1, 0.011); + TestMultiply<SSSE3_8bit, JustUnquantizeC>(8, 256, 256, 1.2, 1.2, 0.064, 0.026); + TestMultiply<SSSE3_8bit, JustUnquantizeC>(8, 2048, 256, 33, 33, 4.4, 4.4); + TestMultiply<SSSE3_8bit, JustUnquantizeC>(320, 256, 256, 1.9, 1.9, 0.1, 0.01); + TestMultiply<SSSE3_8bit, JustUnquantizeC>(472, 256, 256, 2.1, 2.1, 0.1, 0.011); + TestMultiply<SSSE3_8bit, JustUnquantizeC>(248, 256, 256, 1.7, 1.7, 0.1, 0.012); + TestMultiply<SSSE3_8bit, JustUnquantizeC>(200, 256, 256, 1.8, 1.9, 0.1, 0.011); } TEST_CASE ("Multiply AVX2 8bit", "[multiply]") { if (kCPU < CPU_AVX2) return; - TestMultiply<AVX2_8bit>(8, 256, 256, .1, 1, 0.1); - TestMultiply<AVX2_8bit>(8, 2048, 256, 19, 19, 1.8, 1.8); - TestMultiply<AVX2_8bit>(320, 256, 256, .1, 1, 0.1); - TestMultiply<AVX2_8bit>(472, 256, 256, .1, 1, 0.1); - TestMultiply<AVX2_8bit>(248, 256, 256, .1, 1, 0.1); - TestMultiply<AVX2_8bit>(200, 256, 256, .1, 1, 0.1); + TestMultiply<AVX2_8bit, JustUnquantizeC>(8, 256, 256, .1, 1, 0.1); + TestMultiply<AVX2_8bit, JustUnquantizeC>(8, 2048, 256, 19, 19, 1.8, 1.8); + TestMultiply<AVX2_8bit, JustUnquantizeC>(320, 256, 256, .1, 1, 0.1); + TestMultiply<AVX2_8bit, JustUnquantizeC>(472, 256, 256, .1, 1, 0.1); + TestMultiply<AVX2_8bit, JustUnquantizeC>(248, 256, 256, .1, 1, 0.1); + TestMultiply<AVX2_8bit, JustUnquantizeC>(200, 256, 256, .1, 1, 0.1); } TEST_CASE ("Multiply AVX2 16bit", "[multiply]") { if (kCPU < CPU_AVX2) return; - TestMultiply<AVX2_16bit>(8, 256, 256, .1, 1, 0.01); - TestMultiply<AVX2_16bit>(8, 2048, 256, .1, 1, 0.02); - TestMultiply<AVX2_16bit>(320, 256, 256, .1, 1, 0.01); - TestMultiply<AVX2_16bit>(472, 256, 256, .1, 1, 0.01); - TestMultiply<AVX2_16bit>(248, 256, 256, .1, 1, 0.01); - TestMultiply<AVX2_16bit>(200, 256, 256, .1, 1, 0.01); + TestMultiply<AVX2_16bit, JustUnquantizeC>(8, 256, 256, .1, 1, 0.01); + TestMultiply<AVX2_16bit, JustUnquantizeC>(8, 2048, 256, .1, 1, 0.02); + TestMultiply<AVX2_16bit, JustUnquantizeC>(320, 256, 256, .1, 1, 0.01); + TestMultiply<AVX2_16bit, JustUnquantizeC>(472, 256, 256, .1, 1, 0.01); + TestMultiply<AVX2_16bit, JustUnquantizeC>(248, 256, 256, .1, 1, 0.01); + TestMultiply<AVX2_16bit, JustUnquantizeC>(200, 256, 256, .1, 1, 0.01); } #ifndef INTGEMM_NO_AVX512 TEST_CASE ("Multiply AVX512 8bit", "[multiply]") { if (kCPU < CPU_AVX512BW) return; - TestMultiply<AVX512_8bit>(8, 256, 256, .1, 1, 0.062); - TestMultiply<AVX512_8bit>(8, 2048, 256, 4.2, 4, 0.41, 0.37); - TestMultiply<AVX512_8bit>(320, 256, 256, .1, 1, 0.06); - TestMultiply<AVX512_8bit>(472, 256, 256, .1, 1, 0.06); - TestMultiply<AVX512_8bit>(248, 256, 256, .1, 1, 0.06); - TestMultiply<AVX512_8bit>(200, 256, 256, .1, 1, 0.06); + TestMultiply<AVX512_8bit, JustUnquantizeC>(8, 256, 256, .1, 1, 0.062); + TestMultiply<AVX512_8bit, JustUnquantizeC>(8, 2048, 256, 4.2, 4, 0.41, 0.37); + TestMultiply<AVX512_8bit, JustUnquantizeC>(320, 256, 256, .1, 1, 0.06); + TestMultiply<AVX512_8bit, JustUnquantizeC>(472, 256, 256, .1, 1, 0.06); + TestMultiply<AVX512_8bit, JustUnquantizeC>(248, 256, 256, .1, 1, 0.06); + TestMultiply<AVX512_8bit, JustUnquantizeC>(200, 256, 256, .1, 1, 0.06); } TEST_CASE ("Multiply AVX512 16bit", "[multiply]") { if (kCPU < CPU_AVX512BW) return; - TestMultiply<AVX512_16bit>(8, 256, 256, .1, 1, 0.01); - TestMultiply<AVX512_16bit>(8, 2048, 256, .1, 1, 0.011); - TestMultiply<AVX512_16bit>(320, 256, 256, .1, 1, 0.01); - TestMultiply<AVX512_16bit>(472, 256, 256, .1, 1, 0.01); - TestMultiply<AVX512_16bit>(248, 256, 256, .1, 1, 0.01); - TestMultiply<AVX512_16bit>(200, 256, 256, .1, 1, 0.01); + TestMultiply<AVX512_16bit, JustUnquantizeC>(8, 256, 256, .1, 1, 0.01); + TestMultiply<AVX512_16bit, JustUnquantizeC>(8, 2048, 256, .1, 1, 0.011); + TestMultiply<AVX512_16bit, JustUnquantizeC>(320, 256, 256, .1, 1, 0.01); + TestMultiply<AVX512_16bit, JustUnquantizeC>(472, 256, 256, .1, 1, 0.01); + TestMultiply<AVX512_16bit, JustUnquantizeC>(248, 256, 256, .1, 1, 0.01); + TestMultiply<AVX512_16bit, JustUnquantizeC>(200, 256, 256, .1, 1, 0.01); } #endif |