diff options
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | test/multiply_test.cc | 27 | ||||
-rw-r--r-- | test/postprocess/add_bias_test.cc | 2 | ||||
-rw-r--r-- | test/postprocess/pipeline_test.cc | 2 | ||||
-rw-r--r-- | test/postprocess/relu_test.cc | 2 | ||||
-rw-r--r-- | test/postprocess/sigmoid_test.cc | 24 | ||||
-rw-r--r-- | test/postprocess/tanh_test.cc | 24 | ||||
-rw-r--r-- | test/postprocess/unquantize_test.cc | 2 | ||||
-rw-r--r-- | test/quantize_test.cc | 12 | ||||
-rw-r--r-- | test/test.cc | 6 | ||||
-rw-r--r-- | test/test.h | 12 | ||||
-rw-r--r-- | test/utils_test.cc | 22 |
12 files changed, 56 insertions, 80 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 5dde6e5..31b7b53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ add_executable(tests test/postprocess/tanh_test.cc test/postprocess/unquantize_test.cc test/quantize_test.cc + test/test.cc test/utils_test.cc intgemm.cc ) diff --git a/test/multiply_test.cc b/test/multiply_test.cc index f88a73a..93d7127 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -1,22 +1,16 @@ +#include "test/test.h" #include "aligned.h" #include "interleave.h" #include "intgemm.h" #include "multiply.h" #include "postprocess.h" -#define CATCH_CONFIG_RUNNER -#include "3rd_party/catch.hpp" -#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while((void)0, 0) -#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while((void)0, 0) -#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while((void)0, 0) -#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while((void)0, 0) - #include <algorithm> #include <cassert> #include <cmath> -#include <cstring> #include <cstdio> #include <cstdlib> +#include <cstring> #include <iomanip> #include <iostream> #include <memory> @@ -554,20 +548,3 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") { #endif } // namespace intgemm - -int main(int argc, char ** argv) { - return Catch::Session().run(argc, argv); -} - -/* - // Top matrix sizes from Marian - TestBoth(8, 256, 256); - TestBoth(8, 2048, 256); - TestBoth(8, 2048, 256); - TestBoth(320, 256, 256); - TestBoth(472, 256, 256); - TestBoth(248, 256, 256); - TestBoth(200, 256, 256); - return 0; -} -*/ diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc index 8d8b46c..5e893ea 100644 --- a/test/postprocess/add_bias_test.cc +++ b/test/postprocess/add_bias_test.cc @@ -1,4 +1,4 @@ -#include "3rd_party/catch.hpp" +#include "test/test.h" #include "aligned.h" #include "postprocess.h" diff --git a/test/postprocess/pipeline_test.cc b/test/postprocess/pipeline_test.cc index 8d60cff..144ee48 100644 --- a/test/postprocess/pipeline_test.cc +++ b/test/postprocess/pipeline_test.cc @@ -1,4 +1,4 @@ -#include "3rd_party/catch.hpp" +#include "test/test.h" #include "aligned.h" #include "postprocess.h" diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc index fda7a2a..e2f2d11 100644 --- a/test/postprocess/relu_test.cc +++ b/test/postprocess/relu_test.cc @@ -1,4 +1,4 @@ -#include "3rd_party/catch.hpp" +#include "test/test.h" #include "aligned.h" #include "postprocess.h" diff --git a/test/postprocess/sigmoid_test.cc b/test/postprocess/sigmoid_test.cc index fc50e37..43c713c 100644 --- a/test/postprocess/sigmoid_test.cc +++ b/test/postprocess/sigmoid_test.cc @@ -1,15 +1,9 @@ -#include "3rd_party/catch.hpp" +#include "test/test.h" #include "aligned.h" #include "postprocess.h" #include <numeric> -#define CHECK_FLOAT(actual, expected, epsilon) \ - do { \ - if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \ - else { CHECK((actual) == (expected)); } \ - } while(0) - namespace intgemm { INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) { @@ -26,14 +20,14 @@ INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) { auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid()); *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); - CHECK_FLOAT(output[0], 0.0179862f, error_tolerance); // input = -4 - CHECK_FLOAT(output[1], 0.0474259f, error_tolerance); // input = -3 - CHECK_FLOAT(output[2], 0.1192029f, error_tolerance); // input = -2 - CHECK_FLOAT(output[3], 0.2689414f, error_tolerance); // input = -1 - CHECK_FLOAT(output[4], 0.5f , error_tolerance); // input = 0 - CHECK_FLOAT(output[5], 0.7310586f, error_tolerance); // input = 1 - CHECK_FLOAT(output[6], 0.8807970f, error_tolerance); // input = 2 - CHECK_FLOAT(output[7], 0.9525740f, error_tolerance); // input = 3 + CHECK_EPS(output[0], 0.0179862f, error_tolerance); // input = -4 + CHECK_EPS(output[1], 0.0474259f, error_tolerance); // input = -3 + CHECK_EPS(output[2], 0.1192029f, error_tolerance); // input = -2 + CHECK_EPS(output[3], 0.2689414f, error_tolerance); // input = -1 + CHECK_EPS(output[4], 0.5f , error_tolerance); // input = 0 + CHECK_EPS(output[5], 0.7310586f, error_tolerance); // input = 1 + CHECK_EPS(output[6], 0.8807970f, error_tolerance); // input = 2 + CHECK_EPS(output[7], 0.9525740f, error_tolerance); // input = 3 } } diff --git a/test/postprocess/tanh_test.cc b/test/postprocess/tanh_test.cc index 54c34fd..f0e4dc2 100644 --- a/test/postprocess/tanh_test.cc +++ b/test/postprocess/tanh_test.cc @@ -1,15 +1,9 @@ -#include "3rd_party/catch.hpp" +#include "test/test.h" #include "aligned.h" #include "postprocess.h" #include <numeric> -#define CHECK_FLOAT(actual, expected, epsilon) \ - do { \ - if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \ - else { CHECK((actual) == (expected)); } \ - } while(0) - namespace intgemm { INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) { @@ -26,14 +20,14 @@ INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) { auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh()); *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); - CHECK_FLOAT(output[0], -0.7615942f, error_tolerance); // input = -1 - CHECK_FLOAT(output[1], -0.6351490f, error_tolerance); // input = -0.75 - CHECK_FLOAT(output[2], -0.4621172f, error_tolerance); // input = -0.5 - CHECK_FLOAT(output[3], -0.2449187f, error_tolerance); // input = -0.25 - CHECK_FLOAT(output[4], 0.0f , error_tolerance); // input = 0 - CHECK_FLOAT(output[5], 0.2449187f, error_tolerance); // input = 0.25 - CHECK_FLOAT(output[6], 0.4621172f, error_tolerance); // input = 0.5 - CHECK_FLOAT(output[7], 0.6351490f, error_tolerance); // input = 0.75 + CHECK_EPS(output[0], -0.7615942f, error_tolerance); // input = -1 + CHECK_EPS(output[1], -0.6351490f, error_tolerance); // input = -0.75 + CHECK_EPS(output[2], -0.4621172f, error_tolerance); // input = -0.5 + CHECK_EPS(output[3], -0.2449187f, error_tolerance); // input = -0.25 + CHECK_EPS(output[4], 0.0f , error_tolerance); // input = 0 + CHECK_EPS(output[5], 0.2449187f, error_tolerance); // input = 0.25 + CHECK_EPS(output[6], 0.4621172f, error_tolerance); // input = 0.5 + CHECK_EPS(output[7], 0.6351490f, error_tolerance); // input = 0.75 } } diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc index c169289..c33b909 100644 --- a/test/postprocess/unquantize_test.cc +++ b/test/postprocess/unquantize_test.cc @@ -1,4 +1,4 @@ -#include "3rd_party/catch.hpp" +#include "test/test.h" #include "aligned.h" #include "postprocess.h" diff --git a/test/quantize_test.cc b/test/quantize_test.cc index fb866f1..fd7f0a4 100644 --- a/test/quantize_test.cc +++ b/test/quantize_test.cc @@ -1,15 +1,13 @@ -#include "avx512_gemm.h" +#include "test/test.h" +#include "aligned.h" #include "avx2_gemm.h" -#include "ssse3_gemm.h" +#include "avx512_gemm.h" #include "sse2_gemm.h" -#include "aligned.h" - -#include "3rd_party/catch.hpp" +#include "ssse3_gemm.h" #include <cstring> -#include <math.h> - #include <iostream> +#include <math.h> namespace intgemm { namespace { diff --git a/test/test.cc b/test/test.cc new file mode 100644 index 0000000..58c62f8 --- /dev/null +++ b/test/test.cc @@ -0,0 +1,6 @@ +#define CATCH_CONFIG_RUNNER +#include "test/test.h" + +int main(int argc, char ** argv) { + return Catch::Session().run(argc, argv); +} diff --git a/test/test.h b/test/test.h new file mode 100644 index 0000000..572a529 --- /dev/null +++ b/test/test.h @@ -0,0 +1,12 @@ +#include "3rd_party/catch.hpp" + +#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while(0) +#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while(0) +#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while(0) +#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while(0) + +#define CHECK_EPS(actual, expected, epsilon) \ + do { \ + if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \ + else { CHECK((actual) == (expected)); } \ + } while(0) diff --git a/test/utils_test.cc b/test/utils_test.cc index 4391e61..580a872 100644 --- a/test/utils_test.cc +++ b/test/utils_test.cc @@ -1,12 +1,6 @@ -#include "3rd_party/catch.hpp" +#include "test/test.h" #include "utils.h" -#define CHECK_DOUBLE(actual, expected, epsilon) \ - do { \ - if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \ - else { CHECK((actual) == (expected)); } \ - } while(0) - namespace intgemm { namespace { @@ -23,21 +17,21 @@ TEST_CASE("Factorial",) { TEST_CASE("Expi (negative)",) { const double eps = 0.0000001; - CHECK_DOUBLE(expi(-1), 0.3678794411714423, eps); - CHECK_DOUBLE(expi(-2), 0.1353352832366127, eps); - CHECK_DOUBLE(expi(-10), 0.0000453999297625, eps); + CHECK_EPS(expi(-1), 0.3678794411714423, eps); + CHECK_EPS(expi(-2), 0.1353352832366127, eps); + CHECK_EPS(expi(-10), 0.0000453999297625, eps); } TEST_CASE("Expi (zero)",) { const double eps = 0.0000001; - CHECK_DOUBLE(expi(0), 1.0, eps); + CHECK_EPS(expi(0), 1.0, eps); } TEST_CASE("Expi (positive)",) { const double eps = 0.0000001; - CHECK_DOUBLE(expi(1), 2.7182818284590452, eps); - CHECK_DOUBLE(expi(2), 7.3890560989306502, eps); - CHECK_DOUBLE(expi(10), 22026.4657948067165170, eps); + CHECK_EPS(expi(1), 2.7182818284590452, eps); + CHECK_EPS(expi(2), 7.3890560989306502, eps); + CHECK_EPS(expi(10), 22026.4657948067165170, eps); } } |