Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-06-19 17:19:55 +0300
committer: Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-06-21 13:01:22 +0300
commit: 3911560269fb4121ff05492076fa4fcff449ccd4 (patch)
tree: c6a17a427cae4e48d2c4220fba6099ecb7511dbe
parent: f0e69b228ac1e9f120d84e05910b3346bc5f01c7 (diff)
Add Sigmoid postprocessing for AVX2
-rw-r--r-- CMakeLists.txt 1
-rw-r--r-- postprocess.h 30
-rw-r--r-- test/sigmoid_test.cc 38
3 files changed, 69 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0d42c23..4948bbd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -36,6 +36,7 @@ add_executable(tests
test/pipeline_test.cc
test/quantize_test.cc
test/relu_test.cc
+ test/sigmoid_test.cc
intgemm.cc
)
diff --git a/postprocess.h b/postprocess.h
index 0855548..2df946d 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -220,4 +220,34 @@ public:
}
};
+/*
+ * Sigmoid (uses Taylor series approximation of e^x)
+ *
+ * `Sigmoid` is an empty tag/config type; the PostprocessImpl specialization
+ * below is selected by <Sigmoid, CPUType::AVX2> and works on __m256 registers
+ * (eight packed floats), applying the function to every lane independently.
+ *
+ * run() evaluates two algebraically equal forms of sigmoid(x):
+ *   case1 = 1 / (1 + e^-x)
+ *   case2 = e^x / (1 + e^x)
+ * and blends them per lane: case2 where 0 < x, case1 otherwise.
+ * Both the exponentials (exp_approx_taylor, defined elsewhere) and the
+ * reciprocals (_mm256_rcp_ps) are approximations, so results are approximate.
+ * The constructor argument and run()'s `offset` argument are not used.
+ *
+ * NOTE(review): the mask is named "nonnegative" but the comparison is the
+ * strict 0 < x, so lanes with x == 0 take case1. The usual overflow-safe
+ * arrangement picks case1 for positive x and case2 for negative x, i.e. the
+ * opposite selection; whether that matters depends on how exp_approx_taylor
+ * treats large inputs -- confirm against its definition.
+ */
+class Sigmoid {};
+
+template <>
+class PostprocessImpl<Sigmoid, CPUType::AVX2> {
+public:
+ using InputRegister = __m256;
+ using OutputRegister = __m256;
+
+ PostprocessImpl(const Sigmoid& config) {} // config is ignored; no state is stored
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { // offset is unused in this body
+ static const auto const_zero = set1_ps<__m256>(0.f); // function-local statics: initialized once, on first call
+ static const auto const_one = set1_ps<__m256>(1.f);
+
+ auto x = input;
+ auto minus_x = sub_ps(const_zero, x); // 0 - x
+ auto e_x = exp_approx_taylor(x); // approximate e^x
+ auto e_minus_x = exp_approx_taylor(minus_x); // approximate e^-x
+
+ auto sigmoid_case1 = _mm256_rcp_ps(add_ps(const_one, e_minus_x)); // approx. 1 / (1 + e^-x)
+ auto sigmoid_case2 = mul_ps(e_x, _mm256_rcp_ps(add_ps(const_one, e_x))); // approx. e^x / (1 + e^x)
+
+ auto nonnegative_x_mask = _mm256_cmp_ps(const_zero, x, _CMP_LT_OS); // lane set where 0 < x (strict; false for x == 0 and for NaN)
+ return _mm256_blendv_ps(sigmoid_case1, sigmoid_case2, nonnegative_x_mask); // masked lanes take case2, the rest take case1
+ }
+};
+
}
diff --git a/test/sigmoid_test.cc b/test/sigmoid_test.cc
new file mode 100644
index 0000000..86f85d4
--- /dev/null
+++ b/test/sigmoid_test.cc
@@ -0,0 +1,38 @@
+#include "3rd_party/catch.hpp"
+#include "postprocess.h"
+
+#include <numeric>
+
+/*
+ * CHECK_FLOAT(actual, expected, epsilon)
+ *
+ * Approximate floating-point comparison for Catch tests. When
+ * |actual - expected| < epsilon the assertion is recorded as passed via
+ * SUCCEED(); otherwise it falls back to CHECK(actual == expected) so that
+ * Catch reports both original expressions and their values in the failure.
+ *
+ * Every macro parameter is parenthesized at its point of use, so arguments
+ * containing lower-precedence operators (e.g. `cond ? a : b`) expand
+ * correctly. `actual` and `expected` are evaluated once in the tolerance test
+ * and again in the fallback CHECK, so pass side-effect-free expressions.
+ */
+#define CHECK_FLOAT(actual, expected, epsilon) \
+ do { \
+ if (fabs((actual) - (expected)) < (epsilon)) { SUCCEED(); } \
+ else { CHECK((actual) == (expected)); } \
+ } while(0)
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) { // compares the AVX2 Sigmoid postprocessor against reference values for inputs -4..3
+ if (kCPU < CPUType::AVX2) // nothing is checked when kCPU compares below AVX2
+ return;
+
+ const float error_tolerance = 0.001f; // absolute tolerance passed to CHECK_FLOAT
+
+ __m256 input;
+ auto raw = reinterpret_cast<float*>(&input); // view the register as its eight float lanes
+ std::iota(raw, raw + 8, -4); // lanes 0..7 receive -4, -3, ..., 3
+
+ auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid());
+ auto output = postproc.run(input, 0); // second argument is the offset parameter, passed as 0
+ auto raw_output = reinterpret_cast<float*>(&output); // lane-wise view of the result
+
+ CHECK_FLOAT(raw_output[0], 0.0179862f, error_tolerance); // input = -4
+ CHECK_FLOAT(raw_output[1], 0.0474259f, error_tolerance); // input = -3
+ CHECK_FLOAT(raw_output[2], 0.1192029f, error_tolerance); // input = -2
+ CHECK_FLOAT(raw_output[3], 0.2689414f, error_tolerance); // input = -1
+ CHECK_FLOAT(raw_output[4], 0.5f , error_tolerance); // input = 0
+ CHECK_FLOAT(raw_output[5], 0.7310586f, error_tolerance); // input = 1
+ CHECK_FLOAT(raw_output[6], 0.8807970f, error_tolerance); // input = 2
+ CHECK_FLOAT(raw_output[7], 0.9525740f, error_tolerance); // input = 3
+}
+
+}