Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/cops.h
diff options
context:
space:
mode:
Diffstat (limited to 'cops.h')
-rw-r--r-- cops.h 17
1 file changed, 10 insertions, 7 deletions
diff --git a/cops.h b/cops.h
index 60dd206..67796f5 100644
--- a/cops.h
+++ b/cops.h
@@ -141,11 +141,13 @@ class ReLU {
class OnSSE2 {
public:
INTGEMM_SSE2 explicit OnSSE2(const ReLU& from)
- : C_(from.C_), zeros_(setzero_ps<__m128>()), unquant_mult_(set1_ps<__m128>(from.unquant_mult_)) {
+ : C_(from.C_), unquant_mult_(set1_ps<__m128>(from.unquant_mult_)) {
assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m128i) == 0);
}
INTGEMM_SSE2 inline void operator()(Index rowIDX, Index cols, Index colIDX, MultiplyResult128 result) {
+ static const auto zeros_ = setzero_ps<__m128>();
+
auto unquantized0123 = unquantize(result.pack0123, unquant_mult_);
auto nonnegative0123 = max_ps(zeros_, unquantized0123);
storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative0123);
@@ -158,19 +160,20 @@ class ReLU {
private:
float* C_;
__m128 unquant_mult_;
- __m128 zeros_;
};
- using OnSSSE2 = OnSSE2;
+ using OnSSSE3 = OnSSE2;
class OnAVX2 {
public:
INTGEMM_AVX2 explicit OnAVX2(const ReLU& from)
- : C_(from.C_), zeros_(setzero_ps<__m256>()), unquant_mult_(set1_ps<__m256>(from.unquant_mult_)) {
+ : C_(from.C_), unquant_mult_(set1_ps<__m256>(from.unquant_mult_)) {
assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m256i) == 0);
}
INTGEMM_AVX2 inline void operator()(Index rowIDX, Index cols, Index colIDX, __m256i result) {
+ static const auto zeros_ = setzero_ps<__m256>();
+
auto nonnegative = max_ps(zeros_, unquantize(result, unquant_mult_));
storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative);
}
@@ -178,18 +181,19 @@ class ReLU {
private:
float* C_;
__m256 unquant_mult_;
- __m256 zeros_;
};
#ifndef INTGEMM_NO_AVX512
class OnAVX512 {
public:
INTGEMM_AVX512BW explicit OnAVX512(const ReLU& from)
- : C_(from.C_), zeros_(setzero_ps<__m512>()), unquant_mult_(set1_ps<__m512>(from.unquant_mult_)) {
+ : C_(from.C_), unquant_mult_(set1_ps<__m512>(from.unquant_mult_)) {
assert(reinterpret_cast<uintptr_t>(C_) % sizeof(__m512i) == 0);
}
INTGEMM_AVX512BW inline void operator()(Index rowIDX, Index cols, Index colIDX, __m512i result) {
+ static const auto zeros_ = setzero_ps<__m512>();
+
auto nonnegative = max_ps(zeros_, unquantize(result, unquant_mult_));
storeu_ps(C_ + rowIDX*cols + colIDX, nonnegative);
}
@@ -197,7 +201,6 @@ class ReLU {
private:
float* C_;
__m512 unquant_mult_;
- __m512 zeros_;
};
#endif