diff options
Diffstat (limited to 'cops.h')
-rw-r--r-- | cops.h | 17 |
1 files changed, 10 insertions, 7 deletions
@@ -141,11 +141,13 @@ class ReLU { class OnSSE2 { public: INTGEMM_SSE2 explicit OnSSE2(const ReLU& from) - : C_(from.C_), zeros_(setzero_ps<__m128>()), unquant_mult_(set1_ps<__m128>(from.unquant_mult_)) { + : C_(from.C_), unquant_mult_(set1_ps<__m128>(from.unquant_mult_)) { assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0); } INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) { + static const auto zeros_ = setzero_ps<__m128>(); + auto unquantized0123 = unquantize(result.pack0123, unquant_mult_); auto nonnegative0123 = max_ps(zeros_, unquantized0123); storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative0123); @@ -158,19 +160,20 @@ class ReLU { private: float* C_; __m128 unquant_mult_; - __m128 zeros_; }; - using OnSSSE2 = OnSSE2; + using OnSSSE3 = OnSSE2; class OnAVX2 { public: INTGEMM_AVX2 explicit OnAVX2(const ReLU& from) - : C_(from.C_), zeros_(setzero_ps<__m256>()), unquant_mult_(set1_ps<__m256>(from.unquant_mult_)) { + : C_(from.C_), unquant_mult_(set1_ps<__m256>(from.unquant_mult_)) { assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0); } INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) { + static const auto zeros_ = setzero_ps<__m256>(); + auto nonnegative = max_ps(zeros_, unquantize(result, unquant_mult_)); storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative); } @@ -178,18 +181,19 @@ class ReLU { private: float* C_; __m256 unquant_mult_; - __m256 zeros_; }; #ifndef INTGEMM_NO_AVX512 class OnAVX512 { public: INTGEMM_AVX512BW explicit OnAVX512(const ReLU& from) - : C_(from.C_), zeros_(setzero_ps<__m512>()), unquant_mult_(set1_ps<__m512>(from.unquant_mult_)) { + : C_(from.C_), unquant_mult_(set1_ps<__m512>(from.unquant_mult_)) { assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m512i) == 0); } INTGEMM_AVX512BW inline void operator()(Index rowIDX, Index cols, Index colIDX, __m512i result) { + static const auto zeros_ = setzero_ps<__m512>(); + auto nonnegative = max_ps(zeros_, unquantize(result, unquant_mult_)); storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative); } @@ -197,7 +201,6 @@ class ReLU { private: float* C_; __m512 unquant_mult_; - __m512 zeros_; }; #endif |