Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt1
-rw-r--r--test/multiply_test.cc27
-rw-r--r--test/postprocess/add_bias_test.cc2
-rw-r--r--test/postprocess/pipeline_test.cc2
-rw-r--r--test/postprocess/relu_test.cc2
-rw-r--r--test/postprocess/sigmoid_test.cc24
-rw-r--r--test/postprocess/tanh_test.cc24
-rw-r--r--test/postprocess/unquantize_test.cc2
-rw-r--r--test/quantize_test.cc12
-rw-r--r--test/test.cc6
-rw-r--r--test/test.h12
-rw-r--r--test/utils_test.cc22
12 files changed, 56 insertions, 80 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5dde6e5..31b7b53 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -40,6 +40,7 @@ add_executable(tests
test/postprocess/tanh_test.cc
test/postprocess/unquantize_test.cc
test/quantize_test.cc
+ test/test.cc
test/utils_test.cc
intgemm.cc
)
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index f88a73a..93d7127 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -1,22 +1,16 @@
+#include "test/test.h"
#include "aligned.h"
#include "interleave.h"
#include "intgemm.h"
#include "multiply.h"
#include "postprocess.h"
-#define CATCH_CONFIG_RUNNER
-#include "3rd_party/catch.hpp"
-#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while((void)0, 0)
-#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while((void)0, 0)
-#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while((void)0, 0)
-#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while((void)0, 0)
-
#include <algorithm>
#include <cassert>
#include <cmath>
-#include <cstring>
#include <cstdio>
#include <cstdlib>
+#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
@@ -554,20 +548,3 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
#endif
} // namespace intgemm
-
-int main(int argc, char ** argv) {
- return Catch::Session().run(argc, argv);
-}
-
-/*
- // Top matrix sizes from Marian
- TestBoth(8, 256, 256);
- TestBoth(8, 2048, 256);
- TestBoth(8, 2048, 256);
- TestBoth(320, 256, 256);
- TestBoth(472, 256, 256);
- TestBoth(248, 256, 256);
- TestBoth(200, 256, 256);
- return 0;
-}
-*/
diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc
index 8d8b46c..5e893ea 100644
--- a/test/postprocess/add_bias_test.cc
+++ b/test/postprocess/add_bias_test.cc
@@ -1,4 +1,4 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
diff --git a/test/postprocess/pipeline_test.cc b/test/postprocess/pipeline_test.cc
index 8d60cff..144ee48 100644
--- a/test/postprocess/pipeline_test.cc
+++ b/test/postprocess/pipeline_test.cc
@@ -1,4 +1,4 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
index fda7a2a..e2f2d11 100644
--- a/test/postprocess/relu_test.cc
+++ b/test/postprocess/relu_test.cc
@@ -1,4 +1,4 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
diff --git a/test/postprocess/sigmoid_test.cc b/test/postprocess/sigmoid_test.cc
index fc50e37..43c713c 100644
--- a/test/postprocess/sigmoid_test.cc
+++ b/test/postprocess/sigmoid_test.cc
@@ -1,15 +1,9 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
#include <numeric>
-#define CHECK_FLOAT(actual, expected, epsilon) \
- do { \
- if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
- else { CHECK((actual) == (expected)); } \
- } while(0)
-
namespace intgemm {
INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) {
@@ -26,14 +20,14 @@ INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) {
auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid());
*output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
- CHECK_FLOAT(output[0], 0.0179862f, error_tolerance); // input = -4
- CHECK_FLOAT(output[1], 0.0474259f, error_tolerance); // input = -3
- CHECK_FLOAT(output[2], 0.1192029f, error_tolerance); // input = -2
- CHECK_FLOAT(output[3], 0.2689414f, error_tolerance); // input = -1
- CHECK_FLOAT(output[4], 0.5f , error_tolerance); // input = 0
- CHECK_FLOAT(output[5], 0.7310586f, error_tolerance); // input = 1
- CHECK_FLOAT(output[6], 0.8807970f, error_tolerance); // input = 2
- CHECK_FLOAT(output[7], 0.9525740f, error_tolerance); // input = 3
+ CHECK_EPS(output[0], 0.0179862f, error_tolerance); // input = -4
+ CHECK_EPS(output[1], 0.0474259f, error_tolerance); // input = -3
+ CHECK_EPS(output[2], 0.1192029f, error_tolerance); // input = -2
+ CHECK_EPS(output[3], 0.2689414f, error_tolerance); // input = -1
+ CHECK_EPS(output[4], 0.5f , error_tolerance); // input = 0
+ CHECK_EPS(output[5], 0.7310586f, error_tolerance); // input = 1
+ CHECK_EPS(output[6], 0.8807970f, error_tolerance); // input = 2
+ CHECK_EPS(output[7], 0.9525740f, error_tolerance); // input = 3
}
}
diff --git a/test/postprocess/tanh_test.cc b/test/postprocess/tanh_test.cc
index 54c34fd..f0e4dc2 100644
--- a/test/postprocess/tanh_test.cc
+++ b/test/postprocess/tanh_test.cc
@@ -1,15 +1,9 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
#include <numeric>
-#define CHECK_FLOAT(actual, expected, epsilon) \
- do { \
- if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
- else { CHECK((actual) == (expected)); } \
- } while(0)
-
namespace intgemm {
INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
@@ -26,14 +20,14 @@ INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh());
*output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
- CHECK_FLOAT(output[0], -0.7615942f, error_tolerance); // input = -1
- CHECK_FLOAT(output[1], -0.6351490f, error_tolerance); // input = -0.75
- CHECK_FLOAT(output[2], -0.4621172f, error_tolerance); // input = -0.5
- CHECK_FLOAT(output[3], -0.2449187f, error_tolerance); // input = -0.25
- CHECK_FLOAT(output[4], 0.0f , error_tolerance); // input = 0
- CHECK_FLOAT(output[5], 0.2449187f, error_tolerance); // input = 0.25
- CHECK_FLOAT(output[6], 0.4621172f, error_tolerance); // input = 0.5
- CHECK_FLOAT(output[7], 0.6351490f, error_tolerance); // input = 0.75
+ CHECK_EPS(output[0], -0.7615942f, error_tolerance); // input = -1
+ CHECK_EPS(output[1], -0.6351490f, error_tolerance); // input = -0.75
+ CHECK_EPS(output[2], -0.4621172f, error_tolerance); // input = -0.5
+ CHECK_EPS(output[3], -0.2449187f, error_tolerance); // input = -0.25
+ CHECK_EPS(output[4], 0.0f , error_tolerance); // input = 0
+ CHECK_EPS(output[5], 0.2449187f, error_tolerance); // input = 0.25
+ CHECK_EPS(output[6], 0.4621172f, error_tolerance); // input = 0.5
+ CHECK_EPS(output[7], 0.6351490f, error_tolerance); // input = 0.75
}
}
diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc
index c169289..c33b909 100644
--- a/test/postprocess/unquantize_test.cc
+++ b/test/postprocess/unquantize_test.cc
@@ -1,4 +1,4 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
diff --git a/test/quantize_test.cc b/test/quantize_test.cc
index fb866f1..fd7f0a4 100644
--- a/test/quantize_test.cc
+++ b/test/quantize_test.cc
@@ -1,15 +1,13 @@
-#include "avx512_gemm.h"
+#include "test/test.h"
+#include "aligned.h"
#include "avx2_gemm.h"
-#include "ssse3_gemm.h"
+#include "avx512_gemm.h"
#include "sse2_gemm.h"
-#include "aligned.h"
-
-#include "3rd_party/catch.hpp"
+#include "ssse3_gemm.h"
#include <cstring>
-#include <math.h>
-
#include <iostream>
+#include <math.h>
namespace intgemm {
namespace {
diff --git a/test/test.cc b/test/test.cc
new file mode 100644
index 0000000..58c62f8
--- /dev/null
+++ b/test/test.cc
@@ -0,0 +1,6 @@
+#define CATCH_CONFIG_RUNNER
+#include "test/test.h"
+
+int main(int argc, char ** argv) {
+ return Catch::Session().run(argc, argv);
+}
diff --git a/test/test.h b/test/test.h
new file mode 100644
index 0000000..572a529
--- /dev/null
+++ b/test/test.h
@@ -0,0 +1,12 @@
+#include "3rd_party/catch.hpp"
+
+#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while(0)
+#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while(0)
+#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while(0)
+#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while(0)
+
+#define CHECK_EPS(actual, expected, epsilon) \
+ do { \
+ if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
+ else { CHECK((actual) == (expected)); } \
+ } while(0)
diff --git a/test/utils_test.cc b/test/utils_test.cc
index 4391e61..580a872 100644
--- a/test/utils_test.cc
+++ b/test/utils_test.cc
@@ -1,12 +1,6 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "utils.h"
-#define CHECK_DOUBLE(actual, expected, epsilon) \
- do { \
- if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
- else { CHECK((actual) == (expected)); } \
- } while(0)
-
namespace intgemm {
namespace {
@@ -23,21 +17,21 @@ TEST_CASE("Factorial",) {
TEST_CASE("Expi (negative)",) {
const double eps = 0.0000001;
- CHECK_DOUBLE(expi(-1), 0.3678794411714423, eps);
- CHECK_DOUBLE(expi(-2), 0.1353352832366127, eps);
- CHECK_DOUBLE(expi(-10), 0.0000453999297625, eps);
+ CHECK_EPS(expi(-1), 0.3678794411714423, eps);
+ CHECK_EPS(expi(-2), 0.1353352832366127, eps);
+ CHECK_EPS(expi(-10), 0.0000453999297625, eps);
}
TEST_CASE("Expi (zero)",) {
const double eps = 0.0000001;
- CHECK_DOUBLE(expi(0), 1.0, eps);
+ CHECK_EPS(expi(0), 1.0, eps);
}
TEST_CASE("Expi (positive)",) {
const double eps = 0.0000001;
- CHECK_DOUBLE(expi(1), 2.7182818284590452, eps);
- CHECK_DOUBLE(expi(2), 7.3890560989306502, eps);
- CHECK_DOUBLE(expi(10), 22026.4657948067165170, eps);
+ CHECK_EPS(expi(1), 2.7182818284590452, eps);
+ CHECK_EPS(expi(2), 7.3890560989306502, eps);
+ CHECK_EPS(expi(10), 22026.4657948067165170, eps);
}
}