diff options
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | intrinsics.h | 9 | ||||
-rw-r--r-- | postprocess.h | 54 | ||||
-rw-r--r-- | test/tanh_test.cc | 45 |
4 files changed, 109 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4948bbd..2061bf2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,6 +37,7 @@ add_executable(tests
   test/quantize_test.cc
   test/relu_test.cc
   test/sigmoid_test.cc
+  test/tanh_test.cc
 
   intgemm.cc
 )
diff --git a/intrinsics.h b/intrinsics.h
index 8e4bc27..9337e35 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -51,6 +51,9 @@ INTGEMM_SSE2 static inline __m128i cvtps_epi32(__m128 arg) {
 INTGEMM_SSE2 static inline __m128i cvttps_epi32(__m128 a) {
   return _mm_cvttps_epi32(a);
 }
+INTGEMM_SSE2 static inline __m128 div_ps(__m128 a, __m128 b) {
+  return _mm_div_ps(a, b);
+}
 /*
  * Missing i32gather_ps for SSE2
  */
@@ -126,6 +129,9 @@ INTGEMM_AVX2 static inline __m256i cvtps_epi32(__m256 arg) {
 INTGEMM_AVX2 static inline __m256i cvttps_epi32(__m256 a) {
   return _mm256_cvttps_epi32(a);
 }
+INTGEMM_AVX2 static inline __m256 div_ps(__m256 a, __m256 b) {
+  return _mm256_div_ps(a, b);
+}
 INTGEMM_AVX2 static inline __m256 i32gather_ps(float const *base_addr, __m256i vindex, const int scale) {
   return _mm256_i32gather_ps(base_addr, vindex, scale);
 }
@@ -203,6 +209,9 @@ INTGEMM_AVX512BW static inline __m512i cvtps_epi32(__m512 arg) {
 INTGEMM_AVX512BW static inline __m512i cvttps_epi32(__m512 a) {
   return _mm512_cvttps_epi32(a);
 }
+INTGEMM_AVX512BW static inline __m512 div_ps(__m512 a, __m512 b) {
+  return _mm512_div_ps(a, b);
+}
 INTGEMM_AVX512BW static inline __m512 i32gather_ps(float const *base_addr, __m512i vindex, const int scale) {
   return _mm512_i32gather_ps(vindex, base_addr, scale);
 }
diff --git a/postprocess.h b/postprocess.h
index 2df946d..ad9c290 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -250,4 +250,58 @@ public:
   }
 };
 
+/*
+ * Tanh (uses Taylor series approximation of e^x)
+ *
+ * Evaluated per lane as (e^x - e^-x) / (e^x + e^-x), with both exponentials
+ * obtained from exp_approx_taylor.
+ * NOTE(review): behaviour for large |x| depends on exp_approx_taylor, which
+ * is defined outside this patch - confirm the supported input range.
+ */
+class Tanh {};
+
+template <>
+class PostprocessImpl<Tanh, CPUType::AVX2> {
+public:
+  using InputRegister = __m256;
+  using OutputRegister = __m256;
+
+  // Tanh carries no settings, so the config argument is ignored.
+  PostprocessImpl(const Tanh& /*config*/) {}
+
+  // Returns tanh of every float lane of input; offset is not used.
+  INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index /*offset*/) {
+    // A plain local rather than a function-local static avoids a guarded
+    // initialisation check on every call.
+    const auto zero = setzero_ps<__m256>();
+
+    const auto e_x = exp_approx_taylor(input);
+    const auto e_minus_x = exp_approx_taylor(sub_ps(zero, input));
+
+    return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x));
+  }
+};
+
+template <>
+class PostprocessImpl<Tanh, CPUType::AVX512BW> {
+public:
+  using InputRegister = __m512;
+  using OutputRegister = __m512;
+
+  // Tanh carries no settings, so the config argument is ignored.
+  PostprocessImpl(const Tanh& /*config*/) {}
+
+  // Returns tanh of every float lane of input; offset is not used.
+  INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index /*offset*/) {
+    // A plain local rather than a function-local static avoids a guarded
+    // initialisation check on every call.
+    const auto zero = setzero_ps<__m512>();
+
+    const auto e_x = exp_approx_taylor(input);
+    const auto e_minus_x = exp_approx_taylor(sub_ps(zero, input));
+
+    return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x));
+  }
+};
+
 }
diff --git a/test/tanh_test.cc b/test/tanh_test.cc
new file mode 100644
index 0000000..72e2555
--- /dev/null
+++ b/test/tanh_test.cc
@@ -0,0 +1,45 @@
+#include "3rd_party/catch.hpp"
+#include "postprocess.h"
+
+#include <algorithm>
+#include <cmath>
+
+// Passes when actual is within epsilon of expected; otherwise hands the
+// comparison to CHECK so the mismatch is recorded as a failure.
+#define CHECK_FLOAT(actual, expected, epsilon) \
+  do { \
+    if (std::fabs((actual) - (expected)) < (epsilon)) { SUCCEED(); } \
+    else { CHECK((actual) == (expected)); } \
+  } while(0)
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
+  if (kCPU < CPUType::AVX2)
+    return;
+
+  const float error_tolerance = 0.001f;
+
+  __m256 input;
+
+  { // fill the 8 lanes with -1, -0.75, ..., 0.75
+    auto raw = reinterpret_cast<float*>(&input);
+    int n = -4;
+    std::generate(raw, raw + 8, [&n] () { return n++ / 4.f; });
+  }
+
+  auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh());
+  auto output = postproc.run(input, 0);
+  auto raw_output = reinterpret_cast<float*>(&output);
+
+  CHECK_FLOAT(raw_output[0], -0.7615942f, error_tolerance); // input = -1
+  CHECK_FLOAT(raw_output[1], -0.6351490f, error_tolerance); // input = -0.75
+  CHECK_FLOAT(raw_output[2], -0.4621172f, error_tolerance); // input = -0.5
+  CHECK_FLOAT(raw_output[3], -0.2449187f, error_tolerance); // input = -0.25
+  CHECK_FLOAT(raw_output[4],  0.0f      , error_tolerance); // input =  0
+  CHECK_FLOAT(raw_output[5],  0.2449187f, error_tolerance); // input =  0.25
+  CHECK_FLOAT(raw_output[6],  0.4621172f, error_tolerance); // input =  0.5
+  CHECK_FLOAT(raw_output[7],  0.6351490f, error_tolerance); // input =  0.75
+}
+
+}