author    | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-05-24 15:45:11 +0300
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-06-18 16:38:53 +0300
commit    | f0785bea3b42a8e5ab7e322b5ad0dc1e9018d65f (patch)
tree      | acee1462bbf5fd31ec58a1988c2d4f03009eefcf
parent    | 5239e8820c3afa68abd679bfe7cad7e7b9b9893e (diff)
Add support for postprocess pipeline
-rw-r--r-- | CMakeLists.txt         |   1
-rw-r--r-- | README.md              |   2
-rw-r--r-- | avx2_gemm.h            |  15
-rw-r--r-- | avx512_gemm.h          |  11
-rw-r--r-- | benchmark.cc           |   4
-rw-r--r-- | cops.h                 | 212
-rw-r--r-- | example.cc             |   4
-rw-r--r-- | intgemm.h              |  38
-rw-r--r-- | multiply.h             |  43
-rw-r--r-- | postprocess.h          | 221
-rw-r--r-- | postprocess_pipeline.h | 140
-rw-r--r-- | sse2_gemm.h            |   2
-rw-r--r-- | ssse3_gemm.h           |  13
-rw-r--r-- | test/multiply_test.cc  |   7
-rw-r--r-- | test/pipeline_test.cc  |  32
-rw-r--r-- | test/relu_test.cc      | 133
-rw-r--r-- | vec_utils.h            |   6
17 files changed, 532 insertions, 352 deletions
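The commit replaces the `WriteC` functor family from `cops.h` with an output pointer plus a tuple of composable postprocess stages (`Unquantize`, `Identity`, `AddBias`, `ReLU`); stages chain when one stage's `OutputRegister` type matches the next stage's `InputRegister`. A usage sketch — not part of the commit, written against the two-stage pipeline that the updated `TestMultiplyBias` below exercises; `test/pipeline_test.cc` pairs `Unquantize` with `ReLU` the same way:

```cpp
#include "intgemm.h"  // after this commit, pulls in postprocess.h

// Sketch: C = unquantize(A * B) + bias. A_prepared/B_prepared must come from
// PrepareA/PrepareB and satisfy intgemm's alignment and multiple-of-8 asserts.
// Multiply runs the inited pipeline once per SIMD register of int32 results.
void MultiplyAddBias(const int8_t *A_prepared, const int8_t *B_prepared,
                     const float *bias, float *C, float unquant_mult,
                     intgemm::Index A_rows, intgemm::Index width, intgemm::Index B_cols) {
  intgemm::Int8::Multiply(
      A_prepared, B_prepared, C,
      intgemm::CreatePostprocessPipeline(
          intgemm::Unquantize(unquant_mult),  // int32 sums -> scaled floats
          intgemm::AddBias(bias)),            // + bias[column]
      A_rows, width, B_cols);
}
```

Only SSE2 and AVX2 specializations of a stage are needed for the kernels in this commit: even the AVX512 paths initialize the pipeline with `InitPostprocessPipeline<CPUType::CPU_AVX2>` (see avx512_gemm.h below), since their results are reduced to 256-bit registers first.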
diff --git a/CMakeLists.txt b/CMakeLists.txt
index dc89b83..0d42c23 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,6 +33,7 @@ endforeach()
 include_directories(.)
 add_executable(tests
   test/multiply_test.cc
+  test/pipeline_test.cc
   test/quantize_test.cc
   test/relu_test.cc
   intgemm.cc
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@
 intgemm::Int16::PrepareA(A, A_prepared, quant_mult, A_rows, width);
 /* Prepare B for multiplication. This is typically done offline. */
 intgemm::Int16::PrepareB(B, B_prepared, quant_mult, width, B_cols);
 /* Multiply and produce results in C */
-intgemm::Int16::Multiply<intgemm::JustUnquantizeC>(A_prepared.begin(), B_prepared.begin(), intgemm::JustUnquantizeC(C.begin(), 1.0 / (quant_mult * quant_mult)), A_rows, width, B_cols);
+intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), C.begin(), intgemm::CreatePostprocessPipeline(intgemm::Unquantize(1.0 / (quant_mult * quant_mult))), A_rows, width, B_cols);
 ```
 For 8-bit, use `Int8` instead of `Int16`.
diff --git a/avx2_gemm.h b/avx2_gemm.h
index c5ca0bc..a03ff09 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -80,7 +80,7 @@ struct AVX2_16bit {
     avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
   }
 
-  INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, OnAVX2)
+  INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, CPUType::CPU_AVX2)
 
   constexpr static const char *const kName = "16-bit INTGEMM_AVX2";
 
@@ -163,22 +163,13 @@ struct AVX2_8bit {
   static const Index kBTileRow = 32;
   static const Index kBTileCol = 8;
-/*
-  INTGEMM_AVX2 static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
-    PrepareBFor8(input, output, avx2::QuantizeTile8(quant_mult), rows, cols);
-  }*/
-
   INTGEMM_PREPARE_B_8(INTGEMM_AVX2, avx2::QuantizeTile8)
 
   INTGEMM_AVX2 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
     avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end);
   }
-/*
-  INTGEMM_AVX2 static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
-    //Multiply8_SSE2OrAVX2<Multiply8_AVXAVX2, __m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols);
-    Multiply8_SSE2OrAVX2__m256i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
-  }*/
-  INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, OnAVX2)
+
+  INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, CPUType::CPU_AVX2)
 
   constexpr static const char *const kName = "8-bit INTGEMM_AVX2";
diff --git a/avx512_gemm.h b/avx512_gemm.h
index c9233a6..28b94bf 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -166,7 +166,7 @@ struct AVX512_16bit {
   }
 
   /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
-  INTGEMM_MULTIPLY16(__m512i, INTGEMM_AVX512BW, OnAVX2)
+  INTGEMM_MULTIPLY16(__m512i, INTGEMM_AVX512BW, CPUType::CPU_AVX2)
 
   constexpr static const char *const kName = "16-bit AVX512";
 
@@ -217,8 +217,8 @@ struct AVX512_8bit {
   // Special AVX512 implementation due to having 32 registers (so I don't have to
   // allocate registers manually) and no sign instruction.
-  template <class WriteC>
-  INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, WriteC C, Index A_rows, Index width, Index B_cols) {
+  template <typename PostprocessPipeline>
+  INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
     typedef __m512i Integer;
     //typedef __m256 Float; // For quantization we only do 8 at a time.
     // This is copy-paste from Multiply8_SSE2OrAVX2.
@@ -227,7 +227,7 @@ struct AVX512_8bit {
     assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
     assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
     // There's 8 results for INTGEMM_AVX2 to handle.
-    typename WriteC::OnAVX2 write_C(C);
+    auto inited_pipeline = InitPostprocessPipeline<CPUType::CPU_AVX2>(pipeline);
     const int simd_width = width / sizeof(Integer);
     const Integer *B0_col = reinterpret_cast<const Integer*>(B);
     // Added for AVX512.
@@ -324,7 +324,8 @@ struct AVX512_8bit {
       Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
       auto total = PermuteSummer(pack0123, pack4567);
-      write_C(A_rowidx, B_cols, B0_colidx, total);
+      auto result = inited_pipeline.run(total);
+      writer(C, A_rowidx, B_cols, B0_colidx, result);
     }
   }
 }
diff --git a/benchmark.cc b/benchmark.cc
index 5e3d9d8..6477792 100644
--- a/benchmark.cc
+++ b/benchmark.cc
@@ -78,10 +78,10 @@ template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t>
   Backend::PrepareB(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols);
   AlignedVector<float> output(m.A_rows * m.B_cols);
   // Burn in
-  Backend::Multiply(A_prepared.begin(), B_prepared.begin(), JustUnquantizeC(output.begin(), unquant_mult), m.A_rows, m.width, m.B_cols);
+  Backend::Multiply(A_prepared.begin(), B_prepared.begin(), output.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), m.A_rows, m.width, m.B_cols);
   {
     StopWatch w(stats);
-    Backend::Multiply(A_prepared.begin(), B_prepared.begin(), JustUnquantizeC(output.begin(), unquant_mult), m.A_rows, m.width, m.B_cols);
+    Backend::Multiply(A_prepared.begin(), B_prepared.begin(), output.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), m.A_rows, m.width, m.B_cols);
   }
 }
diff --git a/cops.h b/cops.h
deleted file mode 100644
--- a/cops.h
+++ /dev/null
@@ -1,212 +0,0 @@
-#pragma once
-
-#include "intrinsics.h"
-#include "vec_utils.h"
-
-#include <cassert>
-#include <exception>
-
-namespace intgemm {
-
-class JustUnquantizeC {
-  public:
-    JustUnquantizeC(float *C, float unquant_mult) : C_(C), unquant_mult_(unquant_mult) {}
-
-    class OnSSE2 {
-      public:
-        INTGEMM_SSE2 explicit OnSSE2(const JustUnquantizeC &from)
-          : C_(from.C_), unquant_mult_(set1_ps<__m128>(from.unquant_mult_)) {
-          assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0);
-        }
-
-        INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
-          storeu_ps(C_ + rowIDX*cols + colIDX    , unquantize(result.pack0123, unquant_mult_));
-          storeu_ps(C_ + rowIDX*cols + colIDX + 4, unquantize(result.pack4567, unquant_mult_));
-        }
-      private:
-        float *C_;
-        __m128 unquant_mult_;
-    };
-
-    class OnAVX2 {
-      public:
-        INTGEMM_AVX2 explicit OnAVX2(const JustUnquantizeC &from)
-          : C_(from.C_), unquant_mult_(set1_ps<__m256>(from.unquant_mult_)) {
-          assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0);
-        }
-
-        INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
-          storeu_ps(C_ + rowIDX*cols + colIDX, unquantize(result, unquant_mult_));
-        }
-
-      private:
-        float *C_;
-        __m256 unquant_mult_;
-    };
-
-  private:
-    float *C_;
-    float unquant_mult_;
-};
-
-class BiasAddUnquantizeC {
-  public:
-    BiasAddUnquantizeC(float *C, const float *bias, float unquant_mult) : C_(C), bias_(bias), unquant_mult_(unquant_mult) {}
-
-    class OnSSE2 {
-      public:
-        INTGEMM_SSE2 explicit OnSSE2(const BiasAddUnquantizeC &from)
-          : C_(from.C_), bias_(from.bias_), unquant_mult_(_mm_set1_ps(from.unquant_mult_)) {
-          assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0);
-        }
-
-        INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
-          auto biasSection0123 = loadu_ps<__m128>(bias_ + colIDX);
-          auto biasSection4567 = loadu_ps<__m128>(bias_ + colIDX + 4);
-          storeu_ps(C_ + rowIDX*cols + colIDX, add_ps(unquantize(result.pack0123, unquant_mult_), biasSection0123));
-          storeu_ps(C_ + rowIDX*cols + colIDX + 4, add_ps(unquantize(result.pack4567, unquant_mult_), biasSection4567));
-        }
-      private:
-        float *C_;
-        const float *bias_;
-        __m128 unquant_mult_;
-    };
-
-    class OnAVX2 {
-      public:
-        INTGEMM_AVX2 explicit OnAVX2(const BiasAddUnquantizeC &from)
-          : C_(from.C_), bias_(from.bias_), unquant_mult_(_mm256_set1_ps(from.unquant_mult_)) {
-          assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0);
-        }
-
-        INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
-          auto biasSection = loadu_ps<__m256>(bias_ + colIDX);
-          storeu_ps(C_ + rowIDX*cols + colIDX, add_ps(unquantize(result, unquant_mult_), biasSection));
-        }
-
-      private:
-        float *C_;
-        const float *bias_;
-        __m256 unquant_mult_;
-    };
-
-  private:
-    float *C_;
-    const float *bias_;
-    float unquant_mult_;
-};
-
-class Identity {
-  public:
-    explicit Identity(int32_t *C) : C_(C) {}
-
-    class OnSSE2 {
-      public:
-        INTGEMM_SSE2 explicit OnSSE2(const Identity &from)
-          : C_(from.C_) {
-          assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0);
-        }
-
-        INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
-          _mm_storeu_si128(reinterpret_cast<__m128i*>(C_ + rowIDX*cols + colIDX), result.pack0123);
-          _mm_storeu_si128(reinterpret_cast<__m128i*>(C_ + rowIDX*cols + colIDX + 4), result.pack4567);
-        }
-      private:
-        int32_t *C_;
-    };
-
-    class OnAVX2 {
-      public:
-        INTGEMM_AVX2 explicit OnAVX2(const Identity &from)
-          : C_(from.C_) {
-          assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0);
-        }
-
-        INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
-          _mm256_storeu_si256(reinterpret_cast<__m256i*>(C_ + rowIDX*cols + colIDX), result);
-        }
-
-      private:
-        int32_t *C_;
-    };
-
-  private:
-    int32_t *C_;
-};
-
-class ReLU {
-  public:
-    explicit ReLU(float *C, float unquant_mult) : C_(C), unquant_mult_(unquant_mult) {}
-
-    class OnSSE2 {
-      public:
-        INTGEMM_SSE2 explicit OnSSE2(const ReLU& from)
-          : C_(from.C_), unquant_mult_(set1_ps<__m128>(from.unquant_mult_)) {
-          assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0);
-        }
-
-        INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
-          static const auto zeros_ = setzero_ps<__m128>();
-
-          auto unquantized0123 = unquantize(result.pack0123, unquant_mult_);
-          auto nonnegative0123 = max_ps(zeros_, unquantized0123);
-          storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative0123);
-
-          auto unquantized4567 = unquantize(result.pack4567, unquant_mult_);
-          auto nonnegative4567 = max_ps(zeros_, unquantized4567);
-          storeu_ps(C_ + rowIDX*cols + colIDX + 4, nonnegative4567);
-        }
-
-      private:
-        float* C_;
-        __m128 unquant_mult_;
-    };
-
-    using OnSSSE3 = OnSSE2;
-
-    class OnAVX2 {
-      public:
-        INTGEMM_AVX2 explicit OnAVX2(const ReLU& from)
-          : C_(from.C_), unquant_mult_(set1_ps<__m256>(from.unquant_mult_)) {
-          assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0);
-        }
-
-        INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
-          static const auto zeros_ = setzero_ps<__m256>();
-
-          auto nonnegative = max_ps(zeros_, unquantize(result, unquant_mult_));
-          storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative);
-        }
-
-      private:
-        float* C_;
-        __m256 unquant_mult_;
-    };
-
-#ifndef INTGEMM_NO_AVX512
-    class OnAVX512 {
-      public:
-        INTGEMM_AVX512BW explicit OnAVX512(const ReLU& from)
-          : C_(from.C_), unquant_mult_(set1_ps<__m512>(from.unquant_mult_)) {
-          assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m512i) == 0);
-        }
-
-        INTGEMM_AVX512BW inline void operator()(Index rowIDX, Index cols, Index colIDX, __m512i result) {
-          static const auto zeros_ = setzero_ps<__m512>();
-
-          auto nonnegative = max_ps(zeros_, unquantize(result, unquant_mult_));
-          storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative);
-        }
-
-      private:
-        float* C_;
-        __m512 unquant_mult_;
-    };
-#endif
-
-  private:
-    float* C_;
-    float unquant_mult_;
-};
-
-}
diff --git a/example.cc b/example.cc
--- a/example.cc
+++ b/example.cc
@@ -51,7 +51,7 @@ int main() {
   AlignedVector<float> C(A_rows * B_cols);
 
   // Do the actual multiply.
-  intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), intgemm::JustUnquantizeC(C.begin(), 1.0 / (quant_mult * quant_mult)), A_rows, width, B_cols);
+  intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), C.begin(), intgemm::CreatePostprocessPipeline(intgemm::Unquantize(1.0 / (quant_mult * quant_mult))), A_rows, width, B_cols);
   // Sanity check. C will be row major.
   assert(fabs(C[0] - top_left_reference) < 0.05);
 }
@@ -70,7 +70,7 @@ int main() {
   AlignedVector<float> C(A_rows * B_cols);
 
   // Do the actual multiply.
-  intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), intgemm::JustUnquantizeC(C.begin(), 1.0 / (quant_mult * quant_mult)), A_rows, width, B_cols);
+  intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), C.begin(), intgemm::CreatePostprocessPipeline(intgemm::Unquantize(1.0 / (quant_mult * quant_mult))), A_rows, width, B_cols);
   // Sanity check. C will be row major.
   assert(fabs(C[0] - top_left_reference) < 0.05);
 }
diff --git a/intgemm.h b/intgemm.h
--- a/intgemm.h
+++ b/intgemm.h
@@ -48,7 +48,7 @@
 #include "sse2_gemm.h"
 #include "ssse3_gemm.h"
 #include "avx2_gemm.h"
-#include "cops.h"
+#include "postprocess.h"
 
 #ifndef INTGEMM_NO_AVX512
 #include "avx512_gemm.h"
 #endif
@@ -67,8 +67,8 @@ struct Unsupported_16bit {
   static void SelectColumnsB(const int16_t *, int16_t *, Index, const Index *, const Index *) {
     throw UnsupportedCPU();
   }
-  template<class WriteC>
-  static void Multiply(const int16_t *, const int16_t *, WriteC, Index, Index, Index) {
+  template <typename PostprocessPipeline>
+  static void Multiply(const int16_t *, const int16_t *, float *, PostprocessPipeline, Index, Index, Index) {
     throw UnsupportedCPU();
   }
   constexpr static const char *const kName = "16-bit Unsupported";
@@ -84,8 +84,8 @@ struct Unsupported_8bit {
   static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) {
     throw UnsupportedCPU();
   }
-  template<class WriteC>
-  static void Multiply(const int8_t *, const int8_t *, WriteC, Index, Index, Index) {
+  template <typename PostprocessPipeline>
+  static void Multiply(const int8_t *, const int8_t *, float *, PostprocessPipeline, Index, Index, Index) {
     throw UnsupportedCPU();
   }
   constexpr static const char *const kName = "8-bit Unsupported";
@@ -133,15 +133,15 @@ template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported)
 }
 
 /* 16-bit matrix multiplication. */
-template<class WriteC>
+template <typename PostprocessPipeline>
 class Int16Mult {
 public:
   // Multiply C = A * B, presuming A and B have been prepared.
-  static void (*Multiply)(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols);
+  static void (*Multiply)(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols);
 };
 
-template <class WriteC>
-void (*Int16Mult<WriteC>::Multiply)(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply<WriteC>, AVX2_16bit::Multiply<WriteC>, SSE2_16bit::Multiply<WriteC>, SSE2_16bit::Multiply<WriteC>, Unsupported_16bit::Multiply);
+template <typename PostprocessPipeline>
+void (*Int16Mult<PostprocessPipeline>::Multiply)(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply<PostprocessPipeline>, AVX2_16bit::Multiply<PostprocessPipeline>, SSE2_16bit::Multiply<PostprocessPipeline>, SSE2_16bit::Multiply<PostprocessPipeline>, Unsupported_16bit::Multiply);
 
 struct Int16 {
   typedef int16_t Integer;
@@ -172,24 +172,24 @@ struct Int16 {
   static void (*SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
 
   // Multiply C = A * B, presuming A and B have been prepared.
-  template<class WriteC>
-  static void Multiply(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) {
-    Int16Mult<WriteC>::Multiply(A, B, functor, A_rows, width, B_cols);
+  template <typename PostprocessPipeline>
+  static void Multiply(const int16_t *A, const int16_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
+    Int16Mult<PostprocessPipeline>::Multiply(A, B, C, pipeline, A_rows, width, B_cols);
   }
 
   static const char *const kName;
 };
 
 /* 8-bit matrix multiplication */
-template<class WriteC>
+template <typename PostprocessPipeline>
 class Int8Mult {
 public:
   // Multiply C = A * B, presuming A and B have been prepared.
-  static void (*Multiply)(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols);
+  static void (*Multiply)(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols);
 };
 
-template <class WriteC>
-void (*Int8Mult<WriteC>::Multiply)(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply<WriteC>, AVX2_8bit::Multiply<WriteC>, SSSE3_8bit::Multiply<WriteC>, SSSE3_8bit::Multiply<WriteC>, Unsupported_8bit::Multiply);
+template <typename PostprocessPipeline>
+void (*Int8Mult<PostprocessPipeline>::Multiply)(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply<PostprocessPipeline>, AVX2_8bit::Multiply<PostprocessPipeline>, SSSE3_8bit::Multiply<PostprocessPipeline>, SSSE3_8bit::Multiply<PostprocessPipeline>, Unsupported_8bit::Multiply);
 
 struct Int8 {
   typedef int8_t Integer;
@@ -219,9 +219,9 @@ struct Int8 {
   static void (*SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
 
   // Multiply C = A * B, presuming A and B have been prepared.
-  template<class WriteC>
-  static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) {
-    Int8Mult<WriteC>::Multiply(A, B, functor, A_rows, width, B_cols);
+  template <typename PostprocessPipeline>
+  static void Multiply(const int8_t *A, const int8_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
+    Int8Mult<PostprocessPipeline>::Multiply(A, B, C, pipeline, A_rows, width, B_cols);
   }
 
   static const char *const kName;
diff --git a/multiply.h b/multiply.h
--- a/multiply.h
+++ b/multiply.h
@@ -2,10 +2,30 @@
 
 #include "interleave.h"
 #include "intrinsics.h"
+#include "postprocess_pipeline.h"
 #include "vec_utils.h"
 
 namespace intgemm {
 
+INTGEMM_SSE2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m128 result) {
+  *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX) = result;
+}
+
+INTGEMM_SSE2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, RegisterPair128 result) {
+  *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX) = result.pack0123;
+  *reinterpret_cast<__m128*>(C + rowIDX*cols + colIDX + 4) = result.pack4567;
+}
+
+INTGEMM_AVX2 static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m256 result) {
+  *reinterpret_cast<__m256*>(C + rowIDX*cols + colIDX) = result;
+}
+
+#ifndef INTGEMM_NO_AVX512
+INTGEMM_AVX512BW static inline void writer(float* C, Index rowIDX, Index cols, Index colIDX, __m512 result) {
+  *reinterpret_cast<__m512*>(C + rowIDX*cols + colIDX) = result;
+}
+#endif
+
 INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
   // Fold to just using the first 64 bits.
   __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
@@ -17,9 +37,9 @@ INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
   return *reinterpret_cast<float*>(&a);
 }
 
-INTGEMM_SSE2 static inline MultiplyResult128 PermuteSummer(__m128i pack0123, __m128i pack4567) {
+INTGEMM_SSE2 static inline RegisterPair128i PermuteSummer(__m128i pack0123, __m128i pack4567) {
   // No op for 128 bits: already reduced fully.
-  MultiplyResult128 ret;
+  RegisterPair128i ret;
   ret.pack0123 = pack0123;
   ret.pack4567 = pack4567;
   return ret;
@@ -126,14 +146,14 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i)
 // width must be a multiple of the register size.
 // B_cols must be a multiple of 8.
 // Multiply16
-#define INTGEMM_MULTIPLY16(Integer, target, WriteCSubType) \
-  template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \
+#define INTGEMM_MULTIPLY16(Integer, target, cpu_type) \
+template <typename PostprocessPipeline> target static void Multiply(const int16_t *A, const int16_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \
   assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \
   assert(B_cols % 8 == 0); \
   assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
   assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
   const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \
-  typename WriteC::WriteCSubType write_C(C); \
+  auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \
   const Integer *B0_col = reinterpret_cast<const Integer *>(B); \
   for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
     /* Process one row of A at a time.  Doesn't seem to be faster to do multiple rows of A at once.*/ \
@@ -177,7 +197,8 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i)
       Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
       /*The specific implementation may need to reduce further.*/ \
       auto total = PermuteSummer(pack0123, pack4567); \
-      write_C(A_rowidx, B_cols, B0_colidx, total); \
+      auto result = inited_pipeline.run(total); \
+      writer(C, A_rowidx, B_cols, B0_colidx, result); \
     } \
   } \
 } \
@@ -330,15 +351,15 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
   sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a)));
 }
 
 //INTGEMM_AVX2 or INTGEMM_SSSE3 multiply
-#define INTGEMM_MULTIPLY8(Integer, target, WriteCSubType) \
-template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC C, Index A_rows, Index width, Index B_cols) { \
+#define INTGEMM_MULTIPLY8(Integer, target, cpu_type) \
+  template <typename PostprocessPipeline> target static void Multiply(const int8_t *A, const int8_t *B, float* C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) { \
   assert(width % sizeof(Integer) == 0); \
   assert(B_cols % 8 == 0); \
   assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
   assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
   const int simd_width = width / sizeof(Integer); \
   const Integer *B0_col = reinterpret_cast<const Integer*>(B); \
-  typename WriteC::WriteCSubType c_writer(C); \
+  auto inited_pipeline = InitPostprocessPipeline<cpu_type>(pipeline); \
   /*Go over 8 columns of B at a time.*/ \
   for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
     /*Process one row of A at a time.  Doesn't seem to be faster to do multiple rows of A at once.*/ \
@@ -393,8 +414,8 @@ template <class WriteC> target static void Multiply(const int8_t *A, const int8_
       Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \
      Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
       auto total = PermuteSummer(pack0123, pack4567); \
-      /*WriteC(C + A_rowidx * B_cols + B0_colidx, total, unquant_reg);*/ \
-      c_writer(A_rowidx, B_cols, B0_colidx, total); \
+      auto result = inited_pipeline.run(total); \
+      writer(C, A_rowidx, B_cols, B0_colidx, result); \
     } \
   } \
 } \
diff --git a/postprocess.h b/postprocess.h
new file mode 100644
index 0000000..7336819
--- /dev/null
+++ b/postprocess.h
@@ -0,0 +1,221 @@
+#pragma once
+
+#include "intrinsics.h"
+#include "postprocess_pipeline.h"
+#include "vec_utils.h"
+
+namespace intgemm {
+
+/*
+ * Unquantize
+ */
+class Unquantize {
+public:
+  float unquantize_multiplier;
+
+  Unquantize(float unquantize_multiplier) : unquantize_multiplier(unquantize_multiplier) {}
+};
+
+template <>
+class PostprocessImpl<Unquantize, CPUType::CPU_SSE2> {
+public:
+  using InputRegister = RegisterPair128i;
+  using OutputRegister = RegisterPair128;
+
+  INTGEMM_SSE2 PostprocessImpl(const Unquantize& config) {
+    unquantize_multiplier = set1_ps<__m128>(config.unquantize_multiplier);
+  }
+
+  INTGEMM_SSE2 inline OutputRegister run(InputRegister input) {
+    return {
+      mul_ps(cvtepi32_ps(input.pack0123), unquantize_multiplier),
+      mul_ps(cvtepi32_ps(input.pack4567), unquantize_multiplier),
+    };
+  }
+
+private:
+  __m128 unquantize_multiplier;
+};
+
+template <>
+class PostprocessImpl<Unquantize, CPUType::CPU_AVX2> {
+public:
+  using InputRegister = __m256i;
+  using OutputRegister = __m256;
+
+  INTGEMM_AVX2 PostprocessImpl(const Unquantize& config) {
+    unquantize_multiplier = set1_ps<__m256>(config.unquantize_multiplier);
+  }
+
+  INTGEMM_AVX2 inline OutputRegister run(InputRegister input) {
+    return mul_ps(cvtepi32_ps(input), unquantize_multiplier);
+  }
+
+private:
+  __m256 unquantize_multiplier;
+};
+
+template <>
+class PostprocessImpl<Unquantize, CPUType::CPU_AVX512BW> {
+public:
+  using InputRegister = __m512i;
+  using OutputRegister = __m512;
+
+  INTGEMM_AVX512BW PostprocessImpl(const Unquantize& config) {
+    unquantize_multiplier = set1_ps<__m512>(config.unquantize_multiplier);
+  }
+
+  INTGEMM_AVX512BW inline OutputRegister run(InputRegister input) {
+    return mul_ps(cvtepi32_ps(input), unquantize_multiplier);
+  }
+
+private:
+  __m512 unquantize_multiplier;
+};
+
+/*
+ * Identity
+ */
+class Identity {};
+
+template <>
+class PostprocessImpl<Identity, CPUType::CPU_SSE2> {
+public:
+  using InputRegister = RegisterPair128i;
+  using OutputRegister = RegisterPair128i;
+
+  PostprocessImpl(const Identity& config) {}
+
+  INTGEMM_SSE2 inline OutputRegister run(InputRegister input) {
+    return input;
+  }
+};
+
+template <>
+class PostprocessImpl<Identity, CPUType::CPU_AVX2> {
+public:
+  using InputRegister = __m256i;
+  using OutputRegister = __m256i;
+
+  PostprocessImpl(const Identity& config) {}
+
+  INTGEMM_AVX2 inline OutputRegister run(InputRegister input) {
+    return input;
+  }
+};
+
+template <>
+class PostprocessImpl<Identity, CPUType::CPU_AVX512BW> {
+public:
+  using InputRegister = __m512i;
+  using OutputRegister = __m512i;
+
+  PostprocessImpl(const Identity& config) {}
+
+  INTGEMM_AVX512BW inline OutputRegister run(InputRegister input) {
+    return input;
+  }
+};
+
+/*
+ * Add a bias term
+ */
+class AddBias {
+public:
+  const float* bias;
+
+  AddBias(const float* bias) : bias(bias) {}
+};
+
+template <>
+class PostprocessImpl<AddBias, CPUType::CPU_SSE2> {
+public:
+  using InputRegister = RegisterPair128;
+  using OutputRegister = RegisterPair128;
+
+  PostprocessImpl(const AddBias& config) : config(config) {}
+
+  INTGEMM_SSE2 inline OutputRegister run(InputRegister input) {
+    auto bias_term0123 = *reinterpret_cast<const __m128*>(config.bias);
+    auto bias_term4567 = *reinterpret_cast<const __m128*>(config.bias);
+    return {
+      add_ps(input.pack0123, bias_term0123),
+      add_ps(input.pack4567, bias_term4567),
+    };
+  }
+
+private:
+  const AddBias config;
+};
+
+template <>
+class PostprocessImpl<AddBias, CPUType::CPU_AVX2> {
+public:
+  using InputRegister = __m256;
+  using OutputRegister = __m256;
+
+  PostprocessImpl(const AddBias& config) : config(config) {}
+
+  INTGEMM_AVX2 inline OutputRegister run(InputRegister input) {
+    auto bias_term = *reinterpret_cast<const __m256*>(config.bias);
+    return add_ps(input, bias_term);
+  }
+
+private:
+  const AddBias config;
+};
+
+/*
+ * ReLU
+ */
+class ReLU {};
+
+template <>
+class PostprocessImpl<ReLU, CPUType::CPU_SSE2> {
+public:
+  using InputRegister = RegisterPair128;
+  using OutputRegister = RegisterPair128;
+
+  PostprocessImpl(const ReLU& config) {}
+
+  INTGEMM_SSE2 inline OutputRegister run(InputRegister input) {
+    static const auto const_zero = set1_ps<__m128>(0.f);
+    return {
+      max_ps(const_zero, input.pack0123),
+      max_ps(const_zero, input.pack4567),
+    };
+  }
+};
+
+template <>
+class PostprocessImpl<ReLU, CPUType::CPU_SSSE3> : public PostprocessImpl<ReLU, CPUType::CPU_SSE2> {};
+
+template <>
+class PostprocessImpl<ReLU, CPUType::CPU_AVX2> {
+public:
+  using InputRegister = __m256;
+  using OutputRegister = __m256;
+
+  PostprocessImpl(const ReLU& config) {}
+
+  INTGEMM_AVX2 inline OutputRegister run(InputRegister input) {
+    static const auto const_zero = set1_ps<__m256>(0.f);
+    return max_ps(const_zero, input);
+  }
+};
+
+template <>
+class PostprocessImpl<ReLU, CPUType::CPU_AVX512BW> {
+public:
+  using InputRegister = __m512;
+  using OutputRegister = __m512;
+
+  PostprocessImpl(const ReLU& config) {}
+
+  INTGEMM_AVX512BW inline OutputRegister run(InputRegister input) {
+    static const auto const_zero = set1_ps<__m512>(0.f);
+    return max_ps(const_zero, input);
+  }
+};
+
+}
diff --git a/postprocess_pipeline.h b/postprocess_pipeline.h
new file mode 100644
index 0000000..0783ff0
--- /dev/null
+++ b/postprocess_pipeline.h
@@ -0,0 +1,140 @@
+#pragma once
+
+#include "intrinsics.h"
+#include "types.h"
+
+#include <tuple>
+
+namespace intgemm {
+
+template <typename... Stages>
+using PostprocessPipeline = std::tuple<Stages...>;
+
+template <typename... Stages>
+constexpr std::tuple<Stages...> CreatePostprocessPipeline(const Stages&... stages) {
+  return std::tuple<Stages...>(stages...);
+}
+
+template <typename Postprocess, CPUType CpuType>
+class PostprocessImpl;
+
+namespace { // anonymous namespace
+
+template <std::size_t... I>
+struct integer_seq {};
+
+template <std::size_t N, std::size_t... I>
+struct integer_seq_from_one_s : integer_seq_from_one_s<N - 1, N - 1, I...> {};
+
+template <std::size_t... I>
+struct integer_seq_from_one_s<1, I...> : integer_seq<I...> {};
+
+template <typename... Types>
+using integer_seq_from_one = integer_seq_from_one_s<sizeof...(Types) + 1>;
+
+template <typename Stage>
+struct remove_first_stage_type_s { using type = std::tuple<>; };
+
+template <typename FirstStage, typename... RestStages>
+struct remove_first_stage_type_s<std::tuple<FirstStage, RestStages...>> { using type = std::tuple<RestStages...>; };
+
+template <typename... Stages>
+using remove_first_stage_type = typename remove_first_stage_type_s<Stages...>::type;
+
+template <typename FirstStage, typename... RestStages>
+struct get_first_stage_type_s { using type = FirstStage; };
+
+template <typename Stage>
+struct get_first_stage_type_s<Stage> { using type = Stage; };
+
+template <typename... Stages>
+using get_first_stage_type = typename get_first_stage_type_s<Stages...>::type;
+
+template <typename FirstStage, typename... RestStages>
+struct get_last_stage_type_s { using type = typename get_last_stage_type_s<RestStages...>::type; };
+
+template <typename Stage>
+struct get_last_stage_type_s<Stage> { using type = Stage; };
+
+template <typename... Stages>
+using get_last_stage_type = typename get_last_stage_type_s<Stages...>::type;
+
+template <typename Tuple, std::size_t... I>
+constexpr remove_first_stage_type<Tuple> ShiftPostprocessPipelineImpl(const Tuple& pipeline, integer_seq<I...>) {
+  return CreatePostprocessPipeline(std::get<I>(pipeline)...);
+}
+
+template <typename FirstStage, typename... RestStages>
+constexpr std::tuple<RestStages...> ShiftPostprocessPipeline(const std::tuple<FirstStage, RestStages...>& pipeline) {
+  return ShiftPostprocessPipelineImpl(pipeline, integer_seq_from_one<std::tuple<FirstStage, RestStages...>>());
+}
+
+template <CPUType CpuType, typename Stage>
+constexpr std::tuple<PostprocessImpl<Stage, CpuType>> InitPostprocessPipelineImpl(std::tuple<Stage> pipeline) {
+  return std::tuple<PostprocessImpl<Stage, CpuType>>(PostprocessImpl<Stage, CpuType>(std::get<0>(pipeline)));
+}
+
+template <CPUType CpuType, typename FirstStage, typename... RestStages>
+constexpr std::tuple<PostprocessImpl<FirstStage, CpuType>, PostprocessImpl<RestStages, CpuType>...> InitPostprocessPipelineImpl(std::tuple<FirstStage, RestStages...> pipeline) {
+  return std::tuple_cat(
+    std::tuple<PostprocessImpl<FirstStage, CpuType>>(PostprocessImpl<FirstStage, CpuType>(std::get<0>(pipeline))),
+    InitPostprocessPipelineImpl<CpuType, RestStages...>(ShiftPostprocessPipeline(pipeline))
+  );
+}
+
+template <CPUType CpuType>
+struct RunPostprocessPipelineImpl;
+
+#define RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(attribute, cpu_type) \
+  template <> \
+  struct RunPostprocessPipelineImpl<cpu_type> { \
+    template <typename Stage> \
+    attribute static constexpr typename Stage::OutputRegister \
+    run(std::tuple<Stage> pipeline, typename Stage::InputRegister input) { \
+      return std::get<0>(pipeline).run(input); \
+    } \
+    template <typename... Stages> \
+    attribute static constexpr typename get_last_stage_type<Stages...>::OutputRegister \
+    run(std::tuple<Stages...> pipeline, typename get_first_stage_type<Stages...>::InputRegister input) { \
+      return run( \
+        ShiftPostprocessPipeline(pipeline), \
+        std::get<0>(pipeline).run(input)); \
+    } \
+  };
+
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSE2, CPUType::CPU_SSE2)
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSSE3, CPUType::CPU_SSSE3)
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX2, CPUType::CPU_AVX2)
+RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::CPU_AVX512BW)
+
+} // anonymous namespace
+
+template <CPUType CpuType, typename... Stages>
+class InitedPostprocessPipeline {};
+
+template <CPUType CpuType, typename... Stages>
+constexpr InitedPostprocessPipeline<CpuType, Stages...> InitPostprocessPipeline(std::tuple<Stages...> pipeline) {
+  return InitedPostprocessPipeline<CpuType, Stages...>(pipeline);
+}
+
+#define INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(attribute, cpu_type) \
+  template <typename... Stages> \
+  class InitedPostprocessPipeline<cpu_type, Stages...> { \
+  public: \
+    using InputRegister = typename get_first_stage_type<PostprocessImpl<Stages, cpu_type>...>::InputRegister; \
+    using OutputRegister = typename get_last_stage_type<PostprocessImpl<Stages, cpu_type>...>::OutputRegister; \
+    InitedPostprocessPipeline(std::tuple<Stages...> pipeline) \
+      : inited_pipeline(InitPostprocessPipelineImpl<cpu_type, Stages...>(pipeline)) {} \
+    attribute inline OutputRegister run(InputRegister input) { \
+      return RunPostprocessPipelineImpl<cpu_type>::run(inited_pipeline, input); \
+    } \
+  private: \
+    const std::tuple<PostprocessImpl<Stages, cpu_type>...> inited_pipeline; \
+  };
+
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSE2, CPUType::CPU_SSE2)
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSSE3, CPUType::CPU_SSSE3)
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX2, CPUType::CPU_AVX2)
+INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::CPU_AVX512BW)
+
+}
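New stages plug in by pairing a plain configuration class with per-CPU `PostprocessImpl` specializations, as postprocess.h above does for `Unquantize`, `Identity`, `AddBias`, and `ReLU`. A hypothetical extra stage — not in this commit — that clamps each float result into a range, with only the AVX2 specialization shown to illustrate the `InputRegister`/`OutputRegister` contract (the raw `_mm256_min_ps`/`_mm256_max_ps` intrinsics are used here rather than assuming intrinsics.h wrappers for them exist):

```cpp
#include "intrinsics.h"
#include "postprocess_pipeline.h"

namespace intgemm {

// Hypothetical config class: clamp every result into [min, max].
class Clip {
public:
  float min, max;
  Clip(float min, float max) : min(min), max(max) {}
};

// AVX2 implementation; consumes and produces float registers, so it can
// follow any stage whose OutputRegister is __m256 (e.g. Unquantize).
template <>
class PostprocessImpl<Clip, CPUType::CPU_AVX2> {
public:
  using InputRegister = __m256;
  using OutputRegister = __m256;

  INTGEMM_AVX2 PostprocessImpl(const Clip& config) {
    min_ = set1_ps<__m256>(config.min);
    max_ = set1_ps<__m256>(config.max);
  }

  INTGEMM_AVX2 inline OutputRegister run(InputRegister input) {
    return _mm256_max_ps(min_, _mm256_min_ps(max_, input));
  }

private:
  __m256 min_, max_;
};

} // namespace intgemm
```

Such a stage would then compose as `CreatePostprocessPipeline(Unquantize(m), Clip(0.f, 1.f))`.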
diff --git a/sse2_gemm.h b/sse2_gemm.h
index 3ca263f..dfccc5c 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -72,7 +72,7 @@ struct SSE2_16bit {
     //TODO #DEFINE
     sse2::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end);
   }
 
-  INTGEMM_MULTIPLY16(__m128i, INTGEMM_SSE2, OnSSE2)
+  INTGEMM_MULTIPLY16(__m128i, INTGEMM_SSE2, CPUType::CPU_SSE2)
 
   constexpr static const char *const kName = "16-bit INTGEMM_SSE2";
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index 9c21467..4e12b90 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -88,21 +88,14 @@ struct SSSE3_8bit {
   // Tile size for B; B must be a multiple of this block size.
   static const Index kBTileRow = 16;
   static const Index kBTileCol = 8;
-/*
-  INTGEMM_SSSE3 static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
-    PrepareBFor8(input, output, ssse3::QuantizeTile8(quant_mult), rows, cols);
-  }*/
+
   INTGEMM_PREPARE_B_8(INTGEMM_SSSE3, ssse3::QuantizeTile8)
 
   INTGEMM_SSSE3 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
     ssse3::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end);
   }
-/*
-  INTGEMM_SSSE3 static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
-    //Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);
-    Multiply8_SSE2OrAVX2__m128i<JustUnquantizeC>(A, B, JustUnquantizeC(C, unquant_mult), A_rows, width, B_cols);
-  }*/
-  INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, OnSSE2)
+
+  INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::CPU_SSE2)
 
   constexpr static const char *const kName = "8-bit INTGEMM_SSSE3";
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 0cfec02..0c0becf 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -1,7 +1,8 @@
-#include "intgemm.h"
 #include "aligned.h"
 #include "interleave.h"
+#include "intgemm.h"
 #include "multiply.h"
+#include "postprocess.h"
 
 #define CATCH_CONFIG_RUNNER
 #include "3rd_party/catch.hpp"
@@ -366,7 +367,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
   Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
 
   AlignedVector<float> test_C(A_rows * B_cols);
-  Routine::Multiply(A_prep.begin(), B_prep.begin(), JustUnquantizeC(test_C.begin(), unquant_mult), A_rows, width, B_cols);
+  Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult)), A_rows, width, B_cols);
 
   AlignedVector<Integer> B_quant(B.size());
   Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
@@ -415,7 +416,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
   AlignedVector<float> test_C(A_rows * B_cols);
 
-  Routine::Multiply(A_prep.begin(), B_prep.begin(), BiasAddUnquantizeC(test_C.begin(), bias.begin(), unquant_mult), A_rows, width, B_cols);
+  Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult), AddBias(bias.begin())), A_rows, width, B_cols);
 
   AlignedVector<Integer> B_quant(B.size());
   Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
diff --git a/test/pipeline_test.cc b/test/pipeline_test.cc
new file mode 100644
index 0000000..dc3d71e
--- /dev/null
+++ b/test/pipeline_test.cc
@@ -0,0 +1,32 @@
+#include "3rd_party/catch.hpp"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") {
+  if (kCPU < CPU_AVX2)
+    return;
+
+  int raw_input[8];
+  std::iota(raw_input, raw_input + 8, -2);
+
+  auto input = *reinterpret_cast<__m256i*>(raw_input);
+  auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
+  auto inited_pipeline = InitPostprocessPipeline<CPU_AVX2>(pipeline);
+  auto output = inited_pipeline.run(input);
+
+  float* raw_output = reinterpret_cast<float*>(&output);
+
+  CHECK(raw_output[0] == 0.0f); // input = -2
+  CHECK(raw_output[1] == 0.0f); // input = -1
+  CHECK(raw_output[2] == 0.0f); // input = 0
+  CHECK(raw_output[3] == 0.5f); // input = 1
+  CHECK(raw_output[4] == 1.0f); // input = 2
+  CHECK(raw_output[5] == 1.5f); // input = 3
+  CHECK(raw_output[6] == 2.0f); // input = 4
+  CHECK(raw_output[7] == 2.5f); // input = 5
+}
+
+}
diff --git a/test/relu_test.cc b/test/relu_test.cc
index 264cd5a..0c677c9 100644
--- a/test/relu_test.cc
+++ b/test/relu_test.cc
@@ -1,5 +1,5 @@
 #include "3rd_party/catch.hpp"
-#include "cops.h"
+#include "postprocess.h"
 
 #include <numeric>
 
@@ -8,56 +8,48 @@ namespace intgemm {
 INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
   if (kCPU < CPU_SSE2)
     return;
 
-  const unsigned N = 4;
-  int32_t raw_input[2 * N];
-  std::iota(raw_input, raw_input + 2 * N, -2);
-
-  MultiplyResult128 input;
-  input.pack0123 = *reinterpret_cast<__m128i*>(raw_input);
-  input.pack4567 = *reinterpret_cast<__m128i*>(raw_input + N);
-
-  float output[2 * N];
-  std::fill(output, output + 2 * N, 42);
-
-  auto postproc = ReLU::OnSSE2(ReLU(output, 1.f));
-  postproc(0, 1, 0, input);
-
-  CHECK(output[0] == 0.f); // input = -2
-  CHECK(output[1] == 0.f); // input = -1
-  CHECK(output[2] == 0.f); // input = 0
-  CHECK(output[3] == 1.f); // input = 1
-  CHECK(output[4] == 2.f); // input = 2
-  CHECK(output[5] == 3.f); // input = 3
-  CHECK(output[6] == 4.f); // input = 4
-  CHECK(output[7] == 5.f); // input = 5
+  float raw_input[8];
+  std::iota(raw_input, raw_input + 8, -2);
+
+  RegisterPair128 input;
+  input.pack0123 = *reinterpret_cast<__m128*>(raw_input);
+  input.pack4567 = *reinterpret_cast<__m128*>(raw_input + 4);
+
+  auto postproc = PostprocessImpl<ReLU, CPUType::CPU_SSE2>(ReLU());
+  auto output = postproc.run(input);
+  auto raw_output = reinterpret_cast<float*>(&output);
+
+  CHECK(raw_output[0] == 0.f); // input = -2
+  CHECK(raw_output[1] == 0.f); // input = -1
+  CHECK(raw_output[2] == 0.f); // input = 0
+  CHECK(raw_output[3] == 1.f); // input = 1
+  CHECK(raw_output[4] == 2.f); // input = 2
+  CHECK(raw_output[5] == 3.f); // input = 3
+  CHECK(raw_output[6] == 4.f); // input = 4
+  CHECK(raw_output[7] == 5.f); // input = 5
 }
 
 INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) {
   if (kCPU < CPU_AVX2)
     return;
 
-  const unsigned N = 8;
-
-  int32_t raw_input[N];
-  std::iota(raw_input, raw_input + N, -4);
-
-  auto input = *reinterpret_cast<__m256i*>(raw_input);
-
-  float output[N];
-  std::fill(output, output + N, 42);
-
-  auto postproc = ReLU::OnAVX2(ReLU(output, 1.f));
-  postproc(0, 1, 0, input);
-
-  CHECK(output[0] == 0.f); // input = -4
-  CHECK(output[1] == 0.f); // input = -3
-  CHECK(output[2] == 0.f); // input = -2
-  CHECK(output[3] == 0.f); // input = -1
-  CHECK(output[4] == 0.f); // input = 0
-  CHECK(output[5] == 1.f); // input = 1
-  CHECK(output[6] == 2.f); // input = 2
-  CHECK(output[7] == 3.f); // input = 3
+  float raw_input[8];
+  std::iota(raw_input, raw_input + 8, -4);
+
+  auto input = *reinterpret_cast<__m256*>(raw_input);
+  auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX2>(ReLU());
+  auto output = postproc.run(input);
+  auto raw_output = reinterpret_cast<float*>(&output);
+
+  CHECK(raw_output[0] == 0.f); // input = -4
+  CHECK(raw_output[1] == 0.f); // input = -3
+  CHECK(raw_output[2] == 0.f); // input = -2
+  CHECK(raw_output[3] == 0.f); // input = -1
+  CHECK(raw_output[4] == 0.f); // input = 0
+  CHECK(raw_output[5] == 1.f); // input = 1
+  CHECK(raw_output[6] == 2.f); // input = 2
+  CHECK(raw_output[7] == 3.f); // input = 3
 }
 
 #ifndef INTGEMM_NO_AVX512
@@ -66,35 +58,30 @@ INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) {
   if (kCPU < CPU_AVX512BW)
     return;
 
-  const unsigned N = 16;
-
-  int32_t raw_input[N];
-  std::iota(raw_input, raw_input + N, -8);
-
-  auto input = *reinterpret_cast<__m512i*>(raw_input);
-
-  float output[N];
-  std::fill(output, output + N, 42);
-
-  auto postproc = ReLU::OnAVX512(ReLU(output, 1.f));
-  postproc(0, 1, 0, input);
-
-  CHECK(output[0] == 0.f); // input = -8
-  CHECK(output[1] == 0.f); // input = -7
-  CHECK(output[2] == 0.f); // input = -6
-  CHECK(output[3] == 0.f); // input = -5
-  CHECK(output[4] == 0.f); // input = -4
-  CHECK(output[5] == 0.f); // input = -3
-  CHECK(output[6] == 0.f); // input = -2
-  CHECK(output[7] == 0.f); // input = -1
-  CHECK(output[8] == 0.f); // input = 0
-  CHECK(output[9] == 1.f); // input = 1
-  CHECK(output[10] == 2.f); // input = 2
-  CHECK(output[11] == 3.f); // input = 3
-  CHECK(output[12] == 4.f); // input = 4
-  CHECK(output[13] == 5.f); // input = 5
-  CHECK(output[14] == 6.f); // input = 6
-  CHECK(output[15] == 7.f); // input = 7
+  float raw_input[16];
+  std::iota(raw_input, raw_input + 16, -8);
+
+  auto input = *reinterpret_cast<__m512*>(raw_input);
+  auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX512BW>(ReLU());
+  auto output = postproc.run(input);
+  auto raw_output = reinterpret_cast<float*>(&output);
+
+  CHECK(raw_output[0] == 0.f); // input = -8
+  CHECK(raw_output[1] == 0.f); // input = -7
+  CHECK(raw_output[2] == 0.f); // input = -6
+  CHECK(raw_output[3] == 0.f); // input = -5
+  CHECK(raw_output[4] == 0.f); // input = -4
+  CHECK(raw_output[5] == 0.f); // input = -3
+  CHECK(raw_output[6] == 0.f); // input = -2
+  CHECK(raw_output[7] == 0.f); // input = -1
+  CHECK(raw_output[8] == 0.f); // input = 0
+  CHECK(raw_output[9] == 1.f); // input = 1
+  CHECK(raw_output[10] == 2.f); // input = 2
+  CHECK(raw_output[11] == 3.f); // input = 3
+  CHECK(raw_output[12] == 4.f); // input = 4
+  CHECK(raw_output[13] == 5.f); // input = 5
+  CHECK(raw_output[14] == 6.f); // input = 6
+  CHECK(raw_output[15] == 7.f); // input = 7
 }
 
 #endif
diff --git a/vec_utils.h b/vec_utils.h
index 3574ba0..fb6aea4 100644
--- a/vec_utils.h
+++ b/vec_utils.h
@@ -4,10 +4,14 @@
 
 namespace intgemm {
 
-struct MultiplyResult128 {
+struct RegisterPair128i {
   __m128i pack0123, pack4567;
 };
 
+struct RegisterPair128 {
+  __m128 pack0123, pack4567;
+};
+
 /*
 *
 * Quantize
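On the 128-bit path the pipeline carries register pairs: `Multiply` hands the first stage a `RegisterPair128i` (two `__m128i` of int32 sums, the struct renamed from `MultiplyResult128` above), and `Unquantize` turns it into a `RegisterPair128` of floats. A standalone sketch of running an inited pipeline outside `Multiply`, mirroring test/pipeline_test.cc (which covers only AVX2 in this commit) on SSE2; not part of the commit:

```cpp
#include "postprocess.h"

#include <numeric>

namespace intgemm {

INTGEMM_SSE2 void Sse2PipelineSketch() {
  int32_t raw_input[8];
  std::iota(raw_input, raw_input + 8, -2);  // -2, -1, 0, 1, 2, 3, 4, 5

  RegisterPair128i input;
  input.pack0123 = *reinterpret_cast<__m128i*>(raw_input);
  input.pack4567 = *reinterpret_cast<__m128i*>(raw_input + 4);

  // Unquantize: RegisterPair128i -> RegisterPair128; ReLU then clamps the floats.
  auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
  auto inited_pipeline = InitPostprocessPipeline<CPUType::CPU_SSE2>(pipeline);
  RegisterPair128 output = inited_pipeline.run(input);
  // output now holds {0, 0, 0, 0.5, 1, 1.5, 2, 2.5}.
  (void)output;
}

}
```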