Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-06-21 20:06:11 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-06-25 15:34:05 +0300
commitbdfc834841b1e7c30a95ad4f8c1c665acadcf39b (patch)
tree292f20bedea948b5963ee0163a4ae06bee8b3e15 /test/postprocess/tanh_test.cc
parent968713e30ef28b9422b445733080c33641722e33 (diff)
Move postprocess tests to subdir
Diffstat (limited to 'test/postprocess/tanh_test.cc')
-rw-r--r--test/postprocess/tanh_test.cc39
1 files changed, 39 insertions, 0 deletions
diff --git a/test/postprocess/tanh_test.cc b/test/postprocess/tanh_test.cc
new file mode 100644
index 0000000..54c34fd
--- /dev/null
+++ b/test/postprocess/tanh_test.cc
@@ -0,0 +1,39 @@
+#include "3rd_party/catch.hpp"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+#define CHECK_FLOAT(actual, expected, epsilon) \
+ do { \
+ if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
+ else { CHECK((actual) == (expected)); } \
+ } while(0)
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const float error_tolerance = 0.001f;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
+
+ std::generate(input.begin(), input.end(), [] () { static int n = -4; return n++ / 4.f; });
+
+ auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh());
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK_FLOAT(output[0], -0.7615942f, error_tolerance); // input = -1
+ CHECK_FLOAT(output[1], -0.6351490f, error_tolerance); // input = -0.75
+ CHECK_FLOAT(output[2], -0.4621172f, error_tolerance); // input = -0.5
+ CHECK_FLOAT(output[3], -0.2449187f, error_tolerance); // input = -0.25
+ CHECK_FLOAT(output[4], 0.0f , error_tolerance); // input = 0
+ CHECK_FLOAT(output[5], 0.2449187f, error_tolerance); // input = 0.25
+ CHECK_FLOAT(output[6], 0.4621172f, error_tolerance); // input = 0.5
+ CHECK_FLOAT(output[7], 0.6351490f, error_tolerance); // input = 0.75
+}
+
+}