Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- CMakeLists.txt | 1
-rw-r--r-- intrinsics.h | 9
-rw-r--r-- postprocess.h | 41
-rw-r--r-- test/tanh_test.cc | 42
4 files changed, 93 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4948bbd..2061bf2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,6 +37,7 @@ add_executable(tests
test/quantize_test.cc
test/relu_test.cc
test/sigmoid_test.cc
+ test/tanh_test.cc
intgemm.cc
)
diff --git a/intrinsics.h b/intrinsics.h
index 8e4bc27..9337e35 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -51,6 +51,9 @@ INTGEMM_SSE2 static inline __m128i cvtps_epi32(__m128 arg) {
INTGEMM_SSE2 static inline __m128i cvttps_epi32(__m128 a) {
return _mm_cvttps_epi32(a);
}
+INTGEMM_SSE2 static inline __m128 div_ps(__m128 a, __m128 b) {
+ return _mm_div_ps(a, b); // element-wise a / b over the packed floats of the 128-bit registers
+}
/*
* Missing i32gather_ps for SSE2
*/
@@ -126,6 +129,9 @@ INTGEMM_AVX2 static inline __m256i cvtps_epi32(__m256 arg) {
INTGEMM_AVX2 static inline __m256i cvttps_epi32(__m256 a) {
return _mm256_cvttps_epi32(a);
}
+INTGEMM_AVX2 static inline __m256 div_ps(__m256 a, __m256 b) {
+ return _mm256_div_ps(a, b); // element-wise a / b over the packed floats of the 256-bit registers
+}
INTGEMM_AVX2 static inline __m256 i32gather_ps(float const *base_addr, __m256i vindex, const int scale) {
return _mm256_i32gather_ps(base_addr, vindex, scale);
}
@@ -203,6 +209,9 @@ INTGEMM_AVX512BW static inline __m512i cvtps_epi32(__m512 arg) {
INTGEMM_AVX512BW static inline __m512i cvttps_epi32(__m512 a) {
return _mm512_cvttps_epi32(a);
}
+INTGEMM_AVX512BW static inline __m512 div_ps(__m512 a, __m512 b) {
+ return _mm512_div_ps(a, b); // element-wise a / b over the packed floats of the 512-bit registers
+}
INTGEMM_AVX512BW static inline __m512 i32gather_ps(float const *base_addr, __m512i vindex, const int scale) {
return _mm512_i32gather_ps(vindex, base_addr, scale);
}
diff --git a/postprocess.h b/postprocess.h
index 2df946d..ad9c290 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -250,4 +250,45 @@ public:
}
};
+/*
+ * Tanh: empty tag type selecting the PostprocessImpl<Tanh, ...> specializations (they use a Taylor series approximation of e^x)
+ */
+class Tanh {};
+
+template <>
+class PostprocessImpl<Tanh, CPUType::AVX2> {
+public:
+ using InputRegister = __m256;
+ using OutputRegister = __m256;
+ // Tanh is an empty tag type, so there is nothing to keep from config.
+ PostprocessImpl(const Tanh& config) {}
+ // Returns tanh(input) per element as (e^x - e^-x) / (e^x + e^-x); offset is not used.
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
+ const auto const_zero = setzero_ps<__m256>(); // plain local: no guarded static initialization on every call
+
+ auto e_x = exp_approx_taylor(input);
+ auto e_minus_x = exp_approx_taylor(sub_ps(const_zero, input)); // e^-x, with -x formed as 0 - x
+ // NOTE(review): accuracy and valid range follow exp_approx_taylor (defined elsewhere) - confirm for large |input|.
+ return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x));
+ }
+};
+
+template <>
+class PostprocessImpl<Tanh, CPUType::AVX512BW> {
+public:
+ using InputRegister = __m512;
+ using OutputRegister = __m512;
+ // Tanh is an empty tag type, so there is nothing to keep from config.
+ PostprocessImpl(const Tanh& config) {}
+ // Returns tanh(input) per element as (e^x - e^-x) / (e^x + e^-x); offset is not used.
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
+ const auto const_zero = setzero_ps<__m512>(); // plain local: no guarded static initialization on every call
+
+ auto e_x = exp_approx_taylor(input);
+ auto e_minus_x = exp_approx_taylor(sub_ps(const_zero, input)); // e^-x, with -x formed as 0 - x
+ // NOTE(review): accuracy and valid range follow exp_approx_taylor (defined elsewhere) - confirm for large |input|.
+ return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x));
+ }
+};
+
}
diff --git a/test/tanh_test.cc b/test/tanh_test.cc
new file mode 100644
index 0000000..72e2555
--- /dev/null
+++ b/test/tanh_test.cc
@@ -0,0 +1,42 @@
+#include "3rd_party/catch.hpp"
+#include "postprocess.h"
+#include <algorithm>
+#include <cmath>
+#include <numeric>
+#define CHECK_FLOAT(actual, expected, epsilon) /* succeeds when |actual - expected| < epsilon */ \
+ do { \
+ if (std::fabs((actual) - (expected)) < (epsilon)) { SUCCEED(); } \
+ else { CHECK((actual) == (expected)); } /* outside tolerance: record the mismatch through CHECK */ \
+ } while(0)
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
+ // Skip when kCPU is below AVX2.
+ if (kCPU < CPUType::AVX2) return;
+
+ const float tolerance = 0.001f;
+
+ // Eight floats holding -1, -0.75, ..., 0.75 (steps of 0.25).
+ __m256 input;
+ {
+ float* const lanes = reinterpret_cast<float*>(&input);
+ int quarter = -4;
+ std::generate(lanes, lanes + 8, [&quarter] () { return quarter++ / 4.f; });
+ }
+
+ PostprocessImpl<Tanh, CPUType::AVX2> tanh_impl{Tanh()};
+ const __m256 output = tanh_impl.run(input, 0);
+ const float* const result = reinterpret_cast<const float*>(&output);
+
+ CHECK_FLOAT(result[0], -0.7615942f, tolerance); // tanh(-1)
+ CHECK_FLOAT(result[1], -0.6351490f, tolerance); // tanh(-0.75)
+ CHECK_FLOAT(result[2], -0.4621172f, tolerance); // tanh(-0.5)
+ CHECK_FLOAT(result[3], -0.2449187f, tolerance); // tanh(-0.25)
+ CHECK_FLOAT(result[4], 0.0f , tolerance); // tanh(0)
+ CHECK_FLOAT(result[5], 0.2449187f, tolerance); // tanh(0.25)
+ CHECK_FLOAT(result[6], 0.4621172f, tolerance); // tanh(0.5)
+ CHECK_FLOAT(result[7], 0.6351490f, tolerance); // tanh(0.75)
+}
+
+}