diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-06-26 14:15:26 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-26 14:15:26 +0300 |
commit | 53bb9cd73bf00e152d2a2503c372949c395a9c98 (patch) | |
tree | 995268a6eb12c2b821ae8aa0f50911bc141fd0e9 | |
parent | 6bf212836130ba2c59dd845946051deeef4ad09f (diff) | |
parent | c6c11b1ba445db9ebcdd411c688d03c1270d45b5 (diff) |
Merge pull request #18 from kpu/add-missing-unit-tests
Add missing unit tests
-rw-r--r-- | CMakeLists.txt | 12 | ||||
-rw-r--r-- | postprocess.h | 71 | ||||
-rw-r--r-- | postprocess_pipeline.h | 4 | ||||
-rw-r--r-- | test/multiply_test.cc | 27 | ||||
-rw-r--r-- | test/postprocess/add_bias_test.cc | 95 | ||||
-rw-r--r-- | test/postprocess/pipeline_test.cc (renamed from test/pipeline_test.cc) | 2 | ||||
-rw-r--r-- | test/postprocess/relu_test.cc (renamed from test/relu_test.cc) | 2 | ||||
-rw-r--r-- | test/postprocess/sigmoid_test.cc | 33 | ||||
-rw-r--r-- | test/postprocess/tanh_test.cc | 33 | ||||
-rw-r--r-- | test/postprocess/unquantize_test.cc | 88 | ||||
-rw-r--r-- | test/quantize_test.cc | 12 | ||||
-rw-r--r-- | test/sigmoid_test.cc | 39 | ||||
-rw-r--r-- | test/tanh_test.cc | 39 | ||||
-rw-r--r-- | test/test.cc | 6 | ||||
-rw-r--r-- | test/test.h | 12 | ||||
-rw-r--r-- | test/utils_test.cc | 38 |
16 files changed, 352 insertions, 161 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 2061bf2..31b7b53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,11 +33,15 @@ endforeach() include_directories(.) add_executable(tests test/multiply_test.cc - test/pipeline_test.cc + test/postprocess/add_bias_test.cc + test/postprocess/pipeline_test.cc + test/postprocess/relu_test.cc + test/postprocess/sigmoid_test.cc + test/postprocess/tanh_test.cc + test/postprocess/unquantize_test.cc test/quantize_test.cc - test/relu_test.cc - test/sigmoid_test.cc - test/tanh_test.cc + test/test.cc + test/utils_test.cc intgemm.cc ) diff --git a/postprocess.h b/postprocess.h index ad9c290..c4dd7ae 100644 --- a/postprocess.h +++ b/postprocess.h @@ -56,6 +56,8 @@ private: __m256 unquantize_multiplier; }; +#ifndef INTGEMM_NO_AVX512 + template <> class PostprocessImpl<Unquantize, CPUType::AVX512BW> { public: @@ -74,49 +76,7 @@ private: __m512 unquantize_multiplier; }; -/* - * Identity - */ -class Identity {}; - -template <> -class PostprocessImpl<Identity, CPUType::SSE2> { -public: - using InputRegister = RegisterPair128i; - using OutputRegister = RegisterPair128i; - - PostprocessImpl(const Identity& config) {} - - INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { - return input; - } -}; - -template <> -class PostprocessImpl<Identity, CPUType::AVX2> { -public: - using InputRegister = __m256i; - using OutputRegister = __m256i; - - PostprocessImpl(const Identity& config) {} - - INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { - return input; - } -}; - -template <> -class PostprocessImpl<Identity, CPUType::AVX512BW> { -public: - using InputRegister = __m512i; - using OutputRegister = __m512i; - - PostprocessImpl(const Identity& config) {} - - INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { - return input; - } -}; +#endif /* * Add a bias term @@ -167,6 +127,27 @@ private: const AddBias config; }; +#ifndef INTGEMM_NO_AVX512 + +template <> +class PostprocessImpl<AddBias, CPUType::AVX512BW> { +public: + using InputRegister = __m512; + using OutputRegister = __m512; + + PostprocessImpl(const AddBias& config) : config(config) {} + + INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { + auto bias_term = *reinterpret_cast<const __m512*>(config.bias + (offset % config.length)); + return add_ps(input, bias_term); + } + +private: + const AddBias config; +}; + +#endif + /* * ReLU */ @@ -206,6 +187,8 @@ public: } }; +#ifndef INTGEMM_NO_AVX512 + template <> class PostprocessImpl<ReLU, CPUType::AVX512BW> { public: @@ -291,4 +274,6 @@ public: } }; +#endif + } diff --git a/postprocess_pipeline.h b/postprocess_pipeline.h index ad26ac5..361ff2b 100644 --- a/postprocess_pipeline.h +++ b/postprocess_pipeline.h @@ -12,8 +12,8 @@ template <typename... Stages> using PostprocessPipeline = std::tuple<Stages...>; template <typename... Stages> -constexpr std::tuple<Stages...> CreatePostprocessPipeline(const Stages&... stages) { - return std::make_tuple(stages...); +constexpr std::tuple<Stages...> CreatePostprocessPipeline(Stages&&... stages) { + return std::make_tuple(std::forward<Stages>(stages)...); } template <typename Postprocess, CPUType CpuType> diff --git a/test/multiply_test.cc b/test/multiply_test.cc index f88a73a..93d7127 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -1,22 +1,16 @@ +#include "test/test.h" #include "aligned.h" #include "interleave.h" #include "intgemm.h" #include "multiply.h" #include "postprocess.h" -#define CATCH_CONFIG_RUNNER -#include "3rd_party/catch.hpp" -#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while((void)0, 0) -#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while((void)0, 0) -#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while((void)0, 0) -#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while((void)0, 0) - #include <algorithm> #include <cassert> #include <cmath> -#include <cstring> #include <cstdio> #include <cstdlib> +#include <cstring> #include <iomanip> #include <iostream> #include <memory> @@ -554,20 +548,3 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") { #endif } // namespace intgemm - -int main(int argc, char ** argv) { - return Catch::Session().run(argc, argv); -} - -/* - // Top matrix sizes from Marian - TestBoth(8, 256, 256); - TestBoth(8, 2048, 256); - TestBoth(8, 2048, 256); - TestBoth(320, 256, 256); - TestBoth(472, 256, 256); - TestBoth(248, 256, 256); - TestBoth(200, 256, 256); - return 0; -} -*/ diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc new file mode 100644 index 0000000..5e893ea --- /dev/null +++ b/test/postprocess/add_bias_test.cc @@ -0,0 +1,95 @@ +#include "test/test.h" +#include "aligned.h" +#include "postprocess.h" + +#include <numeric> + +namespace intgemm { + +INTGEMM_SSE2 TEST_CASE("AddBias SSE2",) { + if (kCPU < CPUType::SSE2) + return; + + AlignedVector<float> input(8); + AlignedVector<float> bias(8); + AlignedVector<float> output(8); + + std::iota(input.begin(), input.end(), -2); + std::iota(bias.begin(), bias.end(), 0); + + auto postproc = PostprocessImpl<AddBias, CPUType::SSE2>(AddBias(bias.begin(), bias.size())); + auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0); + output.as<__m128>()[0] = output_tmp.pack0123; + output.as<__m128>()[1] = output_tmp.pack4567; + + CHECK(output[0] == -2.f); // input = -2, bias = 0 + CHECK(output[1] == 0.f); // input = -1, bias = 1 + CHECK(output[2] == 2.f); // input = 0, bias = 2 + CHECK(output[3] == 4.f); // input = 1, bias = 3 + CHECK(output[4] == 6.f); // input = 2, bias = 4 + CHECK(output[5] == 8.f); // input = 3, bias = 5 + CHECK(output[6] == 10.f); // input = 4, bias = 6 + CHECK(output[7] == 12.f); // input = 5, bias = 7 +} + +INTGEMM_AVX2 TEST_CASE("AddBias AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + AlignedVector<float> input(8); + AlignedVector<float> bias(8); + AlignedVector<float> output(8); + + std::iota(input.begin(), input.end(), -4); + std::iota(bias.begin(), bias.end(), 0); + + auto postproc = PostprocessImpl<AddBias, CPUType::AVX2>(AddBias(bias.begin(), bias.size())); + *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); + + CHECK(output[0] == -4.f); // input = -4, bias = 0 + CHECK(output[1] == -2.f); // input = -3, bias = 1 + CHECK(output[2] == 0.f); // input = -2, bias = 2 + CHECK(output[3] == 2.f); // input = -1, bias = 3 + CHECK(output[4] == 4.f); // input = 0, bias = 4 + CHECK(output[5] == 6.f); // input = 1, bias = 5 + CHECK(output[6] == 8.f); // input = 2, bias = 6 + CHECK(output[7] == 10.f); // input = 3, bias = 7 +} + +#ifndef INTGEMM_NO_AVX512 + +INTGEMM_AVX512BW TEST_CASE("AddBias AVX512",) { + if (kCPU < CPUType::AVX512BW) + return; + + AlignedVector<float> input(16); + AlignedVector<float> bias(16); + AlignedVector<float> output(16); + + std::iota(input.begin(), input.end(), -8); + std::iota(bias.begin(), bias.end(), 0); + + auto postproc = PostprocessImpl<AddBias, CPUType::AVX512BW>(AddBias(bias.begin(), bias.size())); + *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0); + + CHECK(output[0] == -8.f); // input = -8, bias = 0 + CHECK(output[1] == -6.f); // input = -7, bias = 1 + CHECK(output[2] == -4.f); // input = -6, bias = 2 + CHECK(output[3] == -2.f); // input = -5, bias = 3 + CHECK(output[4] == 0.f); // input = -4, bias = 4 + CHECK(output[5] == 2.f); // input = -3, bias = 5 + CHECK(output[6] == 4.f); // input = -2, bias = 6 + CHECK(output[7] == 6.f); // input = -1, bias = 7 + CHECK(output[8] == 8.f); // input = 0, bias = 8 + CHECK(output[9] == 10.f); // input = 1, bias = 9 + CHECK(output[10] == 12.f); // input = 2, bias = 10 + CHECK(output[11] == 14.f); // input = 3, bias = 11 + CHECK(output[12] == 16.f); // input = 4, bias = 12 + CHECK(output[13] == 18.f); // input = 5, bias = 13 + CHECK(output[14] == 20.f); // input = 6, bias = 14 + CHECK(output[15] == 22.f); // input = 7, bias = 15 +} + +#endif + +} diff --git a/test/pipeline_test.cc b/test/postprocess/pipeline_test.cc index 8d60cff..144ee48 100644 --- a/test/pipeline_test.cc +++ b/test/postprocess/pipeline_test.cc @@ -1,4 +1,4 @@ -#include "3rd_party/catch.hpp" +#include "test/test.h" #include "aligned.h" #include "postprocess.h" diff --git a/test/relu_test.cc b/test/postprocess/relu_test.cc index fda7a2a..e2f2d11 100644 --- a/test/relu_test.cc +++ b/test/postprocess/relu_test.cc @@ -1,4 +1,4 @@ -#include "3rd_party/catch.hpp" +#include "test/test.h" #include "aligned.h" #include "postprocess.h" diff --git a/test/postprocess/sigmoid_test.cc b/test/postprocess/sigmoid_test.cc new file mode 100644 index 0000000..43c713c --- /dev/null +++ b/test/postprocess/sigmoid_test.cc @@ -0,0 +1,33 @@ +#include "test/test.h" +#include "aligned.h" +#include "postprocess.h" + +#include <numeric> + +namespace intgemm { + +INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + const float error_tolerance = 0.001f; + + AlignedVector<float> input(8); + AlignedVector<float> output(8); + + std::iota(input.begin(), input.end(), -4); + + auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid()); + *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); + + CHECK_EPS(output[0], 0.0179862f, error_tolerance); // input = -4 + CHECK_EPS(output[1], 0.0474259f, error_tolerance); // input = -3 + CHECK_EPS(output[2], 0.1192029f, error_tolerance); // input = -2 + CHECK_EPS(output[3], 0.2689414f, error_tolerance); // input = -1 + CHECK_EPS(output[4], 0.5f , error_tolerance); // input = 0 + CHECK_EPS(output[5], 0.7310586f, error_tolerance); // input = 1 + CHECK_EPS(output[6], 0.8807970f, error_tolerance); // input = 2 + CHECK_EPS(output[7], 0.9525740f, error_tolerance); // input = 3 +} + +} diff --git a/test/postprocess/tanh_test.cc b/test/postprocess/tanh_test.cc new file mode 100644 index 0000000..f0e4dc2 --- /dev/null +++ b/test/postprocess/tanh_test.cc @@ -0,0 +1,33 @@ +#include "test/test.h" +#include "aligned.h" +#include "postprocess.h" + +#include <numeric> + +namespace intgemm { + +INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + const float error_tolerance = 0.001f; + + AlignedVector<float> input(8); + AlignedVector<float> output(8); + + std::generate(input.begin(), input.end(), [] () { static int n = -4; return n++ / 4.f; }); + + auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh()); + *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); + + CHECK_EPS(output[0], -0.7615942f, error_tolerance); // input = -1 + CHECK_EPS(output[1], -0.6351490f, error_tolerance); // input = -0.75 + CHECK_EPS(output[2], -0.4621172f, error_tolerance); // input = -0.5 + CHECK_EPS(output[3], -0.2449187f, error_tolerance); // input = -0.25 + CHECK_EPS(output[4], 0.0f , error_tolerance); // input = 0 + CHECK_EPS(output[5], 0.2449187f, error_tolerance); // input = 0.25 + CHECK_EPS(output[6], 0.4621172f, error_tolerance); // input = 0.5 + CHECK_EPS(output[7], 0.6351490f, error_tolerance); // input = 0.75 +} + +} diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc new file mode 100644 index 0000000..c33b909 --- /dev/null +++ b/test/postprocess/unquantize_test.cc @@ -0,0 +1,88 @@ +#include "test/test.h" +#include "aligned.h" +#include "postprocess.h" + +#include <numeric> + +namespace intgemm { + +INTGEMM_SSE2 TEST_CASE("Unquantize SSE2",) { + if (kCPU < CPUType::SSE2) + return; + + AlignedVector<int32_t> input(8); + AlignedVector<float> output(8); + std::iota(input.begin(), input.end(), -2); + + auto postproc = PostprocessImpl<Unquantize, CPUType::SSE2>(Unquantize(0.5f)); + auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0); + output.as<__m128>()[0] = output_tmp.pack0123; + output.as<__m128>()[1] = output_tmp.pack4567; + + CHECK(output[0] == -1.0f); // input = -2 + CHECK(output[1] == -0.5f); // input = -1 + CHECK(output[2] == 0.0f); // input = 0 + CHECK(output[3] == 0.5f); // input = 1 + CHECK(output[4] == 1.0f); // input = 2 + CHECK(output[5] == 1.5f); // input = 3 + CHECK(output[6] == 2.0f); // input = 4 + CHECK(output[7] == 2.5f); // input = 5 +} + +INTGEMM_AVX2 TEST_CASE("Unquantize AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + AlignedVector<int32_t> input(8); + AlignedVector<float> output(8); + + std::iota(input.begin(), input.end(), -4); + + auto postproc = PostprocessImpl<Unquantize, CPUType::AVX2>(Unquantize(0.5f)); + *output.as<__m256>() = postproc.run(*input.as<__m256i>(), 0); + + CHECK(output[0] == -2.0f); // input = -4 + CHECK(output[1] == -1.5f); // input = -3 + CHECK(output[2] == -1.0f); // input = -2 + CHECK(output[3] == -0.5f); // input = -1 + CHECK(output[4] == 0.0f); // input = 0 + CHECK(output[5] == 0.5f); // input = 1 + CHECK(output[6] == 1.0f); // input = 2 + CHECK(output[7] == 1.5f); // input = 3 +} + +#ifndef INTGEMM_NO_AVX512 + +INTGEMM_AVX512BW TEST_CASE("Unquantize AVX512",) { + if (kCPU < CPUType::AVX512BW) + return; + + AlignedVector<int32_t> input(16); + AlignedVector<float> output(16); + + std::iota(input.begin(), input.end(), -8); + + auto postproc = PostprocessImpl<Unquantize, CPUType::AVX512BW>(Unquantize(0.5f)); + *output.as<__m512>() = postproc.run(*input.as<__m512i>(), 0); + + CHECK(output[0] == -4.0f); // input = -8 + CHECK(output[1] == -3.5f); // input = -7 + CHECK(output[2] == -3.0f); // input = -6 + CHECK(output[3] == -2.5f); // input = -5 + CHECK(output[4] == -2.0f); // input = -4 + CHECK(output[5] == -1.5f); // input = -3 + CHECK(output[6] == -1.0f); // input = -2 + CHECK(output[7] == -0.5f); // input = -1 + CHECK(output[8] == 0.0f); // input = 0 + CHECK(output[9] == 0.5f); // input = 1 + CHECK(output[10] == 1.0f); // input = 2 + CHECK(output[11] == 1.5f); // input = 3 + CHECK(output[12] == 2.0f); // input = 4 + CHECK(output[13] == 2.5f); // input = 5 + CHECK(output[14] == 3.0f); // input = 6 + CHECK(output[15] == 3.5f); // input = 7 +} + +#endif + +} diff --git a/test/quantize_test.cc b/test/quantize_test.cc index fb866f1..fd7f0a4 100644 --- a/test/quantize_test.cc +++ b/test/quantize_test.cc @@ -1,15 +1,13 @@ -#include "avx512_gemm.h" +#include "test/test.h" +#include "aligned.h" #include "avx2_gemm.h" -#include "ssse3_gemm.h" +#include "avx512_gemm.h" #include "sse2_gemm.h" -#include "aligned.h" - -#include "3rd_party/catch.hpp" +#include "ssse3_gemm.h" #include <cstring> -#include <math.h> - #include <iostream> +#include <math.h> namespace intgemm { namespace { diff --git a/test/sigmoid_test.cc b/test/sigmoid_test.cc deleted file mode 100644 index fc50e37..0000000 --- a/test/sigmoid_test.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include "3rd_party/catch.hpp" -#include "aligned.h" -#include "postprocess.h" - -#include <numeric> - -#define CHECK_FLOAT(actual, expected, epsilon) \ - do { \ - if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \ - else { CHECK((actual) == (expected)); } \ - } while(0) - -namespace intgemm { - -INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) { - if (kCPU < CPUType::AVX2) - return; - - const float error_tolerance = 0.001f; - - AlignedVector<float> input(8); - AlignedVector<float> output(8); - - std::iota(input.begin(), input.end(), -4); - - auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid()); - *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); - - CHECK_FLOAT(output[0], 0.0179862f, error_tolerance); // input = -4 - CHECK_FLOAT(output[1], 0.0474259f, error_tolerance); // input = -3 - CHECK_FLOAT(output[2], 0.1192029f, error_tolerance); // input = -2 - CHECK_FLOAT(output[3], 0.2689414f, error_tolerance); // input = -1 - CHECK_FLOAT(output[4], 0.5f , error_tolerance); // input = 0 - CHECK_FLOAT(output[5], 0.7310586f, error_tolerance); // input = 1 - CHECK_FLOAT(output[6], 0.8807970f, error_tolerance); // input = 2 - CHECK_FLOAT(output[7], 0.9525740f, error_tolerance); // input = 3 -} - -} diff --git a/test/tanh_test.cc b/test/tanh_test.cc deleted file mode 100644 index 54c34fd..0000000 --- a/test/tanh_test.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include "3rd_party/catch.hpp" -#include "aligned.h" -#include "postprocess.h" - -#include <numeric> - -#define CHECK_FLOAT(actual, expected, epsilon) \ - do { \ - if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \ - else { CHECK((actual) == (expected)); } \ - } while(0) - -namespace intgemm { - -INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) { - if (kCPU < CPUType::AVX2) - return; - - const float error_tolerance = 0.001f; - - AlignedVector<float> input(8); - AlignedVector<float> output(8); - - std::generate(input.begin(), input.end(), [] () { static int n = -4; return n++ / 4.f; }); - - auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh()); - *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); - - CHECK_FLOAT(output[0], -0.7615942f, error_tolerance); // input = -1 - CHECK_FLOAT(output[1], -0.6351490f, error_tolerance); // input = -0.75 - CHECK_FLOAT(output[2], -0.4621172f, error_tolerance); // input = -0.5 - CHECK_FLOAT(output[3], -0.2449187f, error_tolerance); // input = -0.25 - CHECK_FLOAT(output[4], 0.0f , error_tolerance); // input = 0 - CHECK_FLOAT(output[5], 0.2449187f, error_tolerance); // input = 0.25 - CHECK_FLOAT(output[6], 0.4621172f, error_tolerance); // input = 0.5 - CHECK_FLOAT(output[7], 0.6351490f, error_tolerance); // input = 0.75 -} - -} diff --git a/test/test.cc b/test/test.cc new file mode 100644 index 0000000..58c62f8 --- /dev/null +++ b/test/test.cc @@ -0,0 +1,6 @@ +#define CATCH_CONFIG_RUNNER +#include "test/test.h" + +int main(int argc, char ** argv) { + return Catch::Session().run(argc, argv); +} diff --git a/test/test.h b/test/test.h new file mode 100644 index 0000000..572a529 --- /dev/null +++ b/test/test.h @@ -0,0 +1,12 @@ +#include "3rd_party/catch.hpp" + +#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while(0) +#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while(0) +#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while(0) +#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while(0) + +#define CHECK_EPS(actual, expected, epsilon) \ + do { \ + if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \ + else { CHECK((actual) == (expected)); } \ + } while(0) diff --git a/test/utils_test.cc b/test/utils_test.cc new file mode 100644 index 0000000..580a872 --- /dev/null +++ b/test/utils_test.cc @@ -0,0 +1,38 @@ +#include "test/test.h" +#include "utils.h" + +namespace intgemm { +namespace { + +TEST_CASE("Factorial",) { + CHECK(factorial(0) == 1); + CHECK(factorial(1) == 1); + CHECK(factorial(2) == 2); + CHECK(factorial(3) == 6); + CHECK(factorial(4) == 24); + + // Maximum result that fits in unsinged long long + CHECK(factorial(20) == 2432902008176640000); +} + +TEST_CASE("Expi (negative)",) { + const double eps = 0.0000001; + CHECK_EPS(expi(-1), 0.3678794411714423, eps); + CHECK_EPS(expi(-2), 0.1353352832366127, eps); + CHECK_EPS(expi(-10), 0.0000453999297625, eps); +} + +TEST_CASE("Expi (zero)",) { + const double eps = 0.0000001; + CHECK_EPS(expi(0), 1.0, eps); +} + +TEST_CASE("Expi (positive)",) { + const double eps = 0.0000001; + CHECK_EPS(expi(1), 2.7182818284590452, eps); + CHECK_EPS(expi(2), 7.3890560989306502, eps); + CHECK_EPS(expi(10), 22026.4657948067165170, eps); +} + +} +} |