diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-06-19 17:19:55 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-06-21 13:01:22 +0300 |
commit | 3911560269fb4121ff05492076fa4fcff449ccd4 (patch) | |
tree | c6a17a427cae4e48d2c4220fba6099ecb7511dbe | |
parent | f0e69b228ac1e9f120d84e05910b3346bc5f01c7 (diff) |
Add Sigmoid postprocessing for AVX2
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | postprocess.h | 30 | ||||
-rw-r--r-- | test/sigmoid_test.cc | 38 |
3 files changed, 69 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 0d42c23..4948bbd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,7 @@ add_executable(tests test/pipeline_test.cc test/quantize_test.cc test/relu_test.cc + test/sigmoid_test.cc intgemm.cc ) diff --git a/postprocess.h b/postprocess.h index 0855548..2df946d 100644 --- a/postprocess.h +++ b/postprocess.h @@ -220,4 +220,34 @@ public: } }; +/* + * Sigmoid (uses Taylor series approximation of e^x) + */ +class Sigmoid {}; + +template <> +class PostprocessImpl<Sigmoid, CPUType::AVX2> { +public: + using InputRegister = __m256; + using OutputRegister = __m256; + + PostprocessImpl(const Sigmoid& config) {} + + INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = set1_ps<__m256>(0.f); + static const auto const_one = set1_ps<__m256>(1.f); + + auto x = input; + auto minus_x = sub_ps(const_zero, x); + auto e_x = exp_approx_taylor(x); + auto e_minus_x = exp_approx_taylor(minus_x); + + auto sigmoid_case1 = _mm256_rcp_ps(add_ps(const_one, e_minus_x)); + auto sigmoid_case2 = mul_ps(e_x, _mm256_rcp_ps(add_ps(const_one, e_x))); + + auto nonnegative_x_mask = _mm256_cmp_ps(const_zero, x, _CMP_LT_OS); + return _mm256_blendv_ps(sigmoid_case1, sigmoid_case2, nonnegative_x_mask); + } +}; + } diff --git a/test/sigmoid_test.cc b/test/sigmoid_test.cc new file mode 100644 index 0000000..86f85d4 --- /dev/null +++ b/test/sigmoid_test.cc @@ -0,0 +1,38 @@ +#include "3rd_party/catch.hpp" +#include "postprocess.h" + +#include <numeric> + +#define CHECK_FLOAT(actual, expected, epsilon) \ + do { \ + if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \ + else { CHECK((actual) == (expected)); } \ + } while(0) + +namespace intgemm { + +INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + const float error_tolerance = 0.001f; + + __m256 input; + auto raw = reinterpret_cast<float*>(&input); + std::iota(raw, raw + 8, -4); + + auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid()); + auto output = postproc.run(input, 0); + auto raw_output = reinterpret_cast<float*>(&output); + + CHECK_FLOAT(raw_output[0], 0.0179862f, error_tolerance); // input = -4 + CHECK_FLOAT(raw_output[1], 0.0474259f, error_tolerance); // input = -3 + CHECK_FLOAT(raw_output[2], 0.1192029f, error_tolerance); // input = -2 + CHECK_FLOAT(raw_output[3], 0.2689414f, error_tolerance); // input = -1 + CHECK_FLOAT(raw_output[4], 0.5f , error_tolerance); // input = 0 + CHECK_FLOAT(raw_output[5], 0.7310586f, error_tolerance); // input = 1 + CHECK_FLOAT(raw_output[6], 0.8807970f, error_tolerance); // input = 2 + CHECK_FLOAT(raw_output[7], 0.9525740f, error_tolerance); // input = 3 +} + +} |