Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test/postprocess')
-rw-r--r--test/postprocess/add_bias_test.cc2
-rw-r--r--test/postprocess/pipeline_test.cc2
-rw-r--r--test/postprocess/relu_test.cc2
-rw-r--r--test/postprocess/sigmoid_test.cc24
-rw-r--r--test/postprocess/tanh_test.cc24
-rw-r--r--test/postprocess/unquantize_test.cc2
6 files changed, 22 insertions, 34 deletions
diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc
index 8d8b46c..5e893ea 100644
--- a/test/postprocess/add_bias_test.cc
+++ b/test/postprocess/add_bias_test.cc
@@ -1,4 +1,4 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
diff --git a/test/postprocess/pipeline_test.cc b/test/postprocess/pipeline_test.cc
index 8d60cff..144ee48 100644
--- a/test/postprocess/pipeline_test.cc
+++ b/test/postprocess/pipeline_test.cc
@@ -1,4 +1,4 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
index fda7a2a..e2f2d11 100644
--- a/test/postprocess/relu_test.cc
+++ b/test/postprocess/relu_test.cc
@@ -1,4 +1,4 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
diff --git a/test/postprocess/sigmoid_test.cc b/test/postprocess/sigmoid_test.cc
index fc50e37..43c713c 100644
--- a/test/postprocess/sigmoid_test.cc
+++ b/test/postprocess/sigmoid_test.cc
@@ -1,15 +1,9 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
#include <numeric>
-#define CHECK_FLOAT(actual, expected, epsilon) \
- do { \
- if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
- else { CHECK((actual) == (expected)); } \
- } while(0)
-
namespace intgemm {
INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) {
@@ -26,14 +20,14 @@ INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) {
auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid());
*output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
- CHECK_FLOAT(output[0], 0.0179862f, error_tolerance); // input = -4
- CHECK_FLOAT(output[1], 0.0474259f, error_tolerance); // input = -3
- CHECK_FLOAT(output[2], 0.1192029f, error_tolerance); // input = -2
- CHECK_FLOAT(output[3], 0.2689414f, error_tolerance); // input = -1
- CHECK_FLOAT(output[4], 0.5f , error_tolerance); // input = 0
- CHECK_FLOAT(output[5], 0.7310586f, error_tolerance); // input = 1
- CHECK_FLOAT(output[6], 0.8807970f, error_tolerance); // input = 2
- CHECK_FLOAT(output[7], 0.9525740f, error_tolerance); // input = 3
+ CHECK_EPS(output[0], 0.0179862f, error_tolerance); // input = -4
+ CHECK_EPS(output[1], 0.0474259f, error_tolerance); // input = -3
+ CHECK_EPS(output[2], 0.1192029f, error_tolerance); // input = -2
+ CHECK_EPS(output[3], 0.2689414f, error_tolerance); // input = -1
+ CHECK_EPS(output[4], 0.5f , error_tolerance); // input = 0
+ CHECK_EPS(output[5], 0.7310586f, error_tolerance); // input = 1
+ CHECK_EPS(output[6], 0.8807970f, error_tolerance); // input = 2
+ CHECK_EPS(output[7], 0.9525740f, error_tolerance); // input = 3
}
}
diff --git a/test/postprocess/tanh_test.cc b/test/postprocess/tanh_test.cc
index 54c34fd..f0e4dc2 100644
--- a/test/postprocess/tanh_test.cc
+++ b/test/postprocess/tanh_test.cc
@@ -1,15 +1,9 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
#include <numeric>
-#define CHECK_FLOAT(actual, expected, epsilon) \
- do { \
- if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
- else { CHECK((actual) == (expected)); } \
- } while(0)
-
namespace intgemm {
INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
@@ -26,14 +20,14 @@ INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh());
*output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
- CHECK_FLOAT(output[0], -0.7615942f, error_tolerance); // input = -1
- CHECK_FLOAT(output[1], -0.6351490f, error_tolerance); // input = -0.75
- CHECK_FLOAT(output[2], -0.4621172f, error_tolerance); // input = -0.5
- CHECK_FLOAT(output[3], -0.2449187f, error_tolerance); // input = -0.25
- CHECK_FLOAT(output[4], 0.0f , error_tolerance); // input = 0
- CHECK_FLOAT(output[5], 0.2449187f, error_tolerance); // input = 0.25
- CHECK_FLOAT(output[6], 0.4621172f, error_tolerance); // input = 0.5
- CHECK_FLOAT(output[7], 0.6351490f, error_tolerance); // input = 0.75
+ CHECK_EPS(output[0], -0.7615942f, error_tolerance); // input = -1
+ CHECK_EPS(output[1], -0.6351490f, error_tolerance); // input = -0.75
+ CHECK_EPS(output[2], -0.4621172f, error_tolerance); // input = -0.5
+ CHECK_EPS(output[3], -0.2449187f, error_tolerance); // input = -0.25
+ CHECK_EPS(output[4], 0.0f , error_tolerance); // input = 0
+ CHECK_EPS(output[5], 0.2449187f, error_tolerance); // input = 0.25
+ CHECK_EPS(output[6], 0.4621172f, error_tolerance); // input = 0.5
+ CHECK_EPS(output[7], 0.6351490f, error_tolerance); // input = 0.75
}
}
diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc
index c169289..c33b909 100644
--- a/test/postprocess/unquantize_test.cc
+++ b/test/postprocess/unquantize_test.cc
@@ -1,4 +1,4 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"