Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-06-26 14:15:26 +0300
committerGitHub <noreply@github.com>2019-06-26 14:15:26 +0300
commit53bb9cd73bf00e152d2a2503c372949c395a9c98 (patch)
tree995268a6eb12c2b821ae8aa0f50911bc141fd0e9
parent6bf212836130ba2c59dd845946051deeef4ad09f (diff)
parentc6c11b1ba445db9ebcdd411c688d03c1270d45b5 (diff)
Merge pull request #18 from kpu/add-missing-unit-tests
Add missing unit tests
-rw-r--r--CMakeLists.txt12
-rw-r--r--postprocess.h71
-rw-r--r--postprocess_pipeline.h4
-rw-r--r--test/multiply_test.cc27
-rw-r--r--test/postprocess/add_bias_test.cc95
-rw-r--r--test/postprocess/pipeline_test.cc (renamed from test/pipeline_test.cc)2
-rw-r--r--test/postprocess/relu_test.cc (renamed from test/relu_test.cc)2
-rw-r--r--test/postprocess/sigmoid_test.cc33
-rw-r--r--test/postprocess/tanh_test.cc33
-rw-r--r--test/postprocess/unquantize_test.cc88
-rw-r--r--test/quantize_test.cc12
-rw-r--r--test/sigmoid_test.cc39
-rw-r--r--test/tanh_test.cc39
-rw-r--r--test/test.cc6
-rw-r--r--test/test.h12
-rw-r--r--test/utils_test.cc38
16 files changed, 352 insertions, 161 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2061bf2..31b7b53 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,11 +33,15 @@ endforeach()
include_directories(.)
add_executable(tests
test/multiply_test.cc
- test/pipeline_test.cc
+ test/postprocess/add_bias_test.cc
+ test/postprocess/pipeline_test.cc
+ test/postprocess/relu_test.cc
+ test/postprocess/sigmoid_test.cc
+ test/postprocess/tanh_test.cc
+ test/postprocess/unquantize_test.cc
test/quantize_test.cc
- test/relu_test.cc
- test/sigmoid_test.cc
- test/tanh_test.cc
+ test/test.cc
+ test/utils_test.cc
intgemm.cc
)
diff --git a/postprocess.h b/postprocess.h
index ad9c290..c4dd7ae 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -56,6 +56,8 @@ private:
__m256 unquantize_multiplier;
};
+#ifndef INTGEMM_NO_AVX512
+
template <>
class PostprocessImpl<Unquantize, CPUType::AVX512BW> {
public:
@@ -74,49 +76,7 @@ private:
__m512 unquantize_multiplier;
};
-/*
- * Identity
- */
-class Identity {};
-
-template <>
-class PostprocessImpl<Identity, CPUType::SSE2> {
-public:
- using InputRegister = RegisterPair128i;
- using OutputRegister = RegisterPair128i;
-
- PostprocessImpl(const Identity& config) {}
-
- INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
- return input;
- }
-};
-
-template <>
-class PostprocessImpl<Identity, CPUType::AVX2> {
-public:
- using InputRegister = __m256i;
- using OutputRegister = __m256i;
-
- PostprocessImpl(const Identity& config) {}
-
- INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
- return input;
- }
-};
-
-template <>
-class PostprocessImpl<Identity, CPUType::AVX512BW> {
-public:
- using InputRegister = __m512i;
- using OutputRegister = __m512i;
-
- PostprocessImpl(const Identity& config) {}
-
- INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
- return input;
- }
-};
+#endif
/*
* Add a bias term
@@ -167,6 +127,27 @@ private:
const AddBias config;
};
+#ifndef INTGEMM_NO_AVX512
+
+template <>
+class PostprocessImpl<AddBias, CPUType::AVX512BW> {
+public:
+ using InputRegister = __m512;
+ using OutputRegister = __m512;
+
+ PostprocessImpl(const AddBias& config) : config(config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
+ auto bias_term = *reinterpret_cast<const __m512*>(config.bias + (offset % config.length));
+ return add_ps(input, bias_term);
+ }
+
+private:
+ const AddBias config;
+};
+
+#endif
+
/*
* ReLU
*/
@@ -206,6 +187,8 @@ public:
}
};
+#ifndef INTGEMM_NO_AVX512
+
template <>
class PostprocessImpl<ReLU, CPUType::AVX512BW> {
public:
@@ -291,4 +274,6 @@ public:
}
};
+#endif
+
}
diff --git a/postprocess_pipeline.h b/postprocess_pipeline.h
index ad26ac5..361ff2b 100644
--- a/postprocess_pipeline.h
+++ b/postprocess_pipeline.h
@@ -12,8 +12,8 @@ template <typename... Stages>
using PostprocessPipeline = std::tuple<Stages...>;
template <typename... Stages>
-constexpr std::tuple<Stages...> CreatePostprocessPipeline(const Stages&... stages) {
- return std::make_tuple(stages...);
+constexpr std::tuple<Stages...> CreatePostprocessPipeline(Stages&&... stages) {
+ return std::make_tuple(std::forward<Stages>(stages)...);
}
template <typename Postprocess, CPUType CpuType>
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index f88a73a..93d7127 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -1,22 +1,16 @@
+#include "test/test.h"
#include "aligned.h"
#include "interleave.h"
#include "intgemm.h"
#include "multiply.h"
#include "postprocess.h"
-#define CATCH_CONFIG_RUNNER
-#include "3rd_party/catch.hpp"
-#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while((void)0, 0)
-#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while((void)0, 0)
-#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while((void)0, 0)
-#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while((void)0, 0)
-
#include <algorithm>
#include <cassert>
#include <cmath>
-#include <cstring>
#include <cstdio>
#include <cstdlib>
+#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
@@ -554,20 +548,3 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
#endif
} // namespace intgemm
-
-int main(int argc, char ** argv) {
- return Catch::Session().run(argc, argv);
-}
-
-/*
- // Top matrix sizes from Marian
- TestBoth(8, 256, 256);
- TestBoth(8, 2048, 256);
- TestBoth(8, 2048, 256);
- TestBoth(320, 256, 256);
- TestBoth(472, 256, 256);
- TestBoth(248, 256, 256);
- TestBoth(200, 256, 256);
- return 0;
-}
-*/
diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc
new file mode 100644
index 0000000..5e893ea
--- /dev/null
+++ b/test/postprocess/add_bias_test.cc
@@ -0,0 +1,95 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_SSE2 TEST_CASE("AddBias SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> bias(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -2);
+ std::iota(bias.begin(), bias.end(), 0);
+
+ auto postproc = PostprocessImpl<AddBias, CPUType::SSE2>(AddBias(bias.begin(), bias.size()));
+ auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0);
+ output.as<__m128>()[0] = output_tmp.pack0123;
+ output.as<__m128>()[1] = output_tmp.pack4567;
+
+ CHECK(output[0] == -2.f); // input = -2, bias = 0
+ CHECK(output[1] == 0.f); // input = -1, bias = 1
+ CHECK(output[2] == 2.f); // input = 0, bias = 2
+ CHECK(output[3] == 4.f); // input = 1, bias = 3
+ CHECK(output[4] == 6.f); // input = 2, bias = 4
+ CHECK(output[5] == 8.f); // input = 3, bias = 5
+ CHECK(output[6] == 10.f); // input = 4, bias = 6
+ CHECK(output[7] == 12.f); // input = 5, bias = 7
+}
+
+INTGEMM_AVX2 TEST_CASE("AddBias AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> bias(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -4);
+ std::iota(bias.begin(), bias.end(), 0);
+
+ auto postproc = PostprocessImpl<AddBias, CPUType::AVX2>(AddBias(bias.begin(), bias.size()));
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK(output[0] == -4.f); // input = -4, bias = 0
+ CHECK(output[1] == -2.f); // input = -3, bias = 1
+ CHECK(output[2] == 0.f); // input = -2, bias = 2
+ CHECK(output[3] == 2.f); // input = -1, bias = 3
+ CHECK(output[4] == 4.f); // input = 0, bias = 4
+ CHECK(output[5] == 6.f); // input = 1, bias = 5
+ CHECK(output[6] == 8.f); // input = 2, bias = 6
+ CHECK(output[7] == 10.f); // input = 3, bias = 7
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("AddBias AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ AlignedVector<float> input(16);
+ AlignedVector<float> bias(16);
+ AlignedVector<float> output(16);
+
+ std::iota(input.begin(), input.end(), -8);
+ std::iota(bias.begin(), bias.end(), 0);
+
+ auto postproc = PostprocessImpl<AddBias, CPUType::AVX512BW>(AddBias(bias.begin(), bias.size()));
+ *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0);
+
+ CHECK(output[0] == -8.f); // input = -8, bias = 0
+ CHECK(output[1] == -6.f); // input = -7, bias = 1
+ CHECK(output[2] == -4.f); // input = -6, bias = 2
+ CHECK(output[3] == -2.f); // input = -5, bias = 3
+ CHECK(output[4] == 0.f); // input = -4, bias = 4
+ CHECK(output[5] == 2.f); // input = -3, bias = 5
+ CHECK(output[6] == 4.f); // input = -2, bias = 6
+ CHECK(output[7] == 6.f); // input = -1, bias = 7
+ CHECK(output[8] == 8.f); // input = 0, bias = 8
+ CHECK(output[9] == 10.f); // input = 1, bias = 9
+ CHECK(output[10] == 12.f); // input = 2, bias = 10
+ CHECK(output[11] == 14.f); // input = 3, bias = 11
+ CHECK(output[12] == 16.f); // input = 4, bias = 12
+ CHECK(output[13] == 18.f); // input = 5, bias = 13
+ CHECK(output[14] == 20.f); // input = 6, bias = 14
+ CHECK(output[15] == 22.f); // input = 7, bias = 15
+}
+
+#endif
+
+}
diff --git a/test/pipeline_test.cc b/test/postprocess/pipeline_test.cc
index 8d60cff..144ee48 100644
--- a/test/pipeline_test.cc
+++ b/test/postprocess/pipeline_test.cc
@@ -1,4 +1,4 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
diff --git a/test/relu_test.cc b/test/postprocess/relu_test.cc
index fda7a2a..e2f2d11 100644
--- a/test/relu_test.cc
+++ b/test/postprocess/relu_test.cc
@@ -1,4 +1,4 @@
-#include "3rd_party/catch.hpp"
+#include "test/test.h"
#include "aligned.h"
#include "postprocess.h"
diff --git a/test/postprocess/sigmoid_test.cc b/test/postprocess/sigmoid_test.cc
new file mode 100644
index 0000000..43c713c
--- /dev/null
+++ b/test/postprocess/sigmoid_test.cc
@@ -0,0 +1,33 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const float error_tolerance = 0.001f;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -4);
+
+ auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid());
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK_EPS(output[0], 0.0179862f, error_tolerance); // input = -4
+ CHECK_EPS(output[1], 0.0474259f, error_tolerance); // input = -3
+ CHECK_EPS(output[2], 0.1192029f, error_tolerance); // input = -2
+ CHECK_EPS(output[3], 0.2689414f, error_tolerance); // input = -1
+ CHECK_EPS(output[4], 0.5f , error_tolerance); // input = 0
+ CHECK_EPS(output[5], 0.7310586f, error_tolerance); // input = 1
+ CHECK_EPS(output[6], 0.8807970f, error_tolerance); // input = 2
+ CHECK_EPS(output[7], 0.9525740f, error_tolerance); // input = 3
+}
+
+}
diff --git a/test/postprocess/tanh_test.cc b/test/postprocess/tanh_test.cc
new file mode 100644
index 0000000..f0e4dc2
--- /dev/null
+++ b/test/postprocess/tanh_test.cc
@@ -0,0 +1,33 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const float error_tolerance = 0.001f;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
+
+ std::generate(input.begin(), input.end(), [] () { static int n = -4; return n++ / 4.f; });
+
+ auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh());
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK_EPS(output[0], -0.7615942f, error_tolerance); // input = -1
+ CHECK_EPS(output[1], -0.6351490f, error_tolerance); // input = -0.75
+ CHECK_EPS(output[2], -0.4621172f, error_tolerance); // input = -0.5
+ CHECK_EPS(output[3], -0.2449187f, error_tolerance); // input = -0.25
+ CHECK_EPS(output[4], 0.0f , error_tolerance); // input = 0
+ CHECK_EPS(output[5], 0.2449187f, error_tolerance); // input = 0.25
+ CHECK_EPS(output[6], 0.4621172f, error_tolerance); // input = 0.5
+ CHECK_EPS(output[7], 0.6351490f, error_tolerance); // input = 0.75
+}
+
+}
diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc
new file mode 100644
index 0000000..c33b909
--- /dev/null
+++ b/test/postprocess/unquantize_test.cc
@@ -0,0 +1,88 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_SSE2 TEST_CASE("Unquantize SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ AlignedVector<int32_t> input(8);
+ AlignedVector<float> output(8);
+ std::iota(input.begin(), input.end(), -2);
+
+ auto postproc = PostprocessImpl<Unquantize, CPUType::SSE2>(Unquantize(0.5f));
+ auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
+ output.as<__m128>()[0] = output_tmp.pack0123;
+ output.as<__m128>()[1] = output_tmp.pack4567;
+
+ CHECK(output[0] == -1.0f); // input = -2
+ CHECK(output[1] == -0.5f); // input = -1
+ CHECK(output[2] == 0.0f); // input = 0
+ CHECK(output[3] == 0.5f); // input = 1
+ CHECK(output[4] == 1.0f); // input = 2
+ CHECK(output[5] == 1.5f); // input = 3
+ CHECK(output[6] == 2.0f); // input = 4
+ CHECK(output[7] == 2.5f); // input = 5
+}
+
+INTGEMM_AVX2 TEST_CASE("Unquantize AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ AlignedVector<int32_t> input(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -4);
+
+ auto postproc = PostprocessImpl<Unquantize, CPUType::AVX2>(Unquantize(0.5f));
+ *output.as<__m256>() = postproc.run(*input.as<__m256i>(), 0);
+
+ CHECK(output[0] == -2.0f); // input = -4
+ CHECK(output[1] == -1.5f); // input = -3
+ CHECK(output[2] == -1.0f); // input = -2
+ CHECK(output[3] == -0.5f); // input = -1
+ CHECK(output[4] == 0.0f); // input = 0
+ CHECK(output[5] == 0.5f); // input = 1
+ CHECK(output[6] == 1.0f); // input = 2
+ CHECK(output[7] == 1.5f); // input = 3
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("Unquantize AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ AlignedVector<int32_t> input(16);
+ AlignedVector<float> output(16);
+
+ std::iota(input.begin(), input.end(), -8);
+
+ auto postproc = PostprocessImpl<Unquantize, CPUType::AVX512BW>(Unquantize(0.5f));
+ *output.as<__m512>() = postproc.run(*input.as<__m512i>(), 0);
+
+ CHECK(output[0] == -4.0f); // input = -8
+ CHECK(output[1] == -3.5f); // input = -7
+ CHECK(output[2] == -3.0f); // input = -6
+ CHECK(output[3] == -2.5f); // input = -5
+ CHECK(output[4] == -2.0f); // input = -4
+ CHECK(output[5] == -1.5f); // input = -3
+ CHECK(output[6] == -1.0f); // input = -2
+ CHECK(output[7] == -0.5f); // input = -1
+ CHECK(output[8] == 0.0f); // input = 0
+ CHECK(output[9] == 0.5f); // input = 1
+ CHECK(output[10] == 1.0f); // input = 2
+ CHECK(output[11] == 1.5f); // input = 3
+ CHECK(output[12] == 2.0f); // input = 4
+ CHECK(output[13] == 2.5f); // input = 5
+ CHECK(output[14] == 3.0f); // input = 6
+ CHECK(output[15] == 3.5f); // input = 7
+}
+
+#endif
+
+}
diff --git a/test/quantize_test.cc b/test/quantize_test.cc
index fb866f1..fd7f0a4 100644
--- a/test/quantize_test.cc
+++ b/test/quantize_test.cc
@@ -1,15 +1,13 @@
-#include "avx512_gemm.h"
+#include "test/test.h"
+#include "aligned.h"
#include "avx2_gemm.h"
-#include "ssse3_gemm.h"
+#include "avx512_gemm.h"
#include "sse2_gemm.h"
-#include "aligned.h"
-
-#include "3rd_party/catch.hpp"
+#include "ssse3_gemm.h"
#include <cstring>
-#include <math.h>
-
#include <iostream>
+#include <math.h>
namespace intgemm {
namespace {
diff --git a/test/sigmoid_test.cc b/test/sigmoid_test.cc
deleted file mode 100644
index fc50e37..0000000
--- a/test/sigmoid_test.cc
+++ /dev/null
@@ -1,39 +0,0 @@
-#include "3rd_party/catch.hpp"
-#include "aligned.h"
-#include "postprocess.h"
-
-#include <numeric>
-
-#define CHECK_FLOAT(actual, expected, epsilon) \
- do { \
- if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
- else { CHECK((actual) == (expected)); } \
- } while(0)
-
-namespace intgemm {
-
-INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) {
- if (kCPU < CPUType::AVX2)
- return;
-
- const float error_tolerance = 0.001f;
-
- AlignedVector<float> input(8);
- AlignedVector<float> output(8);
-
- std::iota(input.begin(), input.end(), -4);
-
- auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid());
- *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
-
- CHECK_FLOAT(output[0], 0.0179862f, error_tolerance); // input = -4
- CHECK_FLOAT(output[1], 0.0474259f, error_tolerance); // input = -3
- CHECK_FLOAT(output[2], 0.1192029f, error_tolerance); // input = -2
- CHECK_FLOAT(output[3], 0.2689414f, error_tolerance); // input = -1
- CHECK_FLOAT(output[4], 0.5f , error_tolerance); // input = 0
- CHECK_FLOAT(output[5], 0.7310586f, error_tolerance); // input = 1
- CHECK_FLOAT(output[6], 0.8807970f, error_tolerance); // input = 2
- CHECK_FLOAT(output[7], 0.9525740f, error_tolerance); // input = 3
-}
-
-}
diff --git a/test/tanh_test.cc b/test/tanh_test.cc
deleted file mode 100644
index 54c34fd..0000000
--- a/test/tanh_test.cc
+++ /dev/null
@@ -1,39 +0,0 @@
-#include "3rd_party/catch.hpp"
-#include "aligned.h"
-#include "postprocess.h"
-
-#include <numeric>
-
-#define CHECK_FLOAT(actual, expected, epsilon) \
- do { \
- if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
- else { CHECK((actual) == (expected)); } \
- } while(0)
-
-namespace intgemm {
-
-INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
- if (kCPU < CPUType::AVX2)
- return;
-
- const float error_tolerance = 0.001f;
-
- AlignedVector<float> input(8);
- AlignedVector<float> output(8);
-
- std::generate(input.begin(), input.end(), [] () { static int n = -4; return n++ / 4.f; });
-
- auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh());
- *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
-
- CHECK_FLOAT(output[0], -0.7615942f, error_tolerance); // input = -1
- CHECK_FLOAT(output[1], -0.6351490f, error_tolerance); // input = -0.75
- CHECK_FLOAT(output[2], -0.4621172f, error_tolerance); // input = -0.5
- CHECK_FLOAT(output[3], -0.2449187f, error_tolerance); // input = -0.25
- CHECK_FLOAT(output[4], 0.0f , error_tolerance); // input = 0
- CHECK_FLOAT(output[5], 0.2449187f, error_tolerance); // input = 0.25
- CHECK_FLOAT(output[6], 0.4621172f, error_tolerance); // input = 0.5
- CHECK_FLOAT(output[7], 0.6351490f, error_tolerance); // input = 0.75
-}
-
-}
diff --git a/test/test.cc b/test/test.cc
new file mode 100644
index 0000000..58c62f8
--- /dev/null
+++ b/test/test.cc
@@ -0,0 +1,6 @@
+#define CATCH_CONFIG_RUNNER
+#include "test/test.h"
+
+int main(int argc, char ** argv) {
+ return Catch::Session().run(argc, argv);
+}
diff --git a/test/test.h b/test/test.h
new file mode 100644
index 0000000..572a529
--- /dev/null
+++ b/test/test.h
@@ -0,0 +1,12 @@
+#include "3rd_party/catch.hpp"
+
+#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while(0)
+#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while(0)
+#define REQUIRE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE(cond); } while(0)
+#define REQUIRE_FALSE_MESSAGE(cond, msg) do { INFO(msg); REQUIRE_FALSE(cond); } while(0)
+
+#define CHECK_EPS(actual, expected, epsilon) \
+ do { \
+ if (fabs((actual) - (expected)) < epsilon) { SUCCEED(); } \
+ else { CHECK((actual) == (expected)); } \
+ } while(0)
diff --git a/test/utils_test.cc b/test/utils_test.cc
new file mode 100644
index 0000000..580a872
--- /dev/null
+++ b/test/utils_test.cc
@@ -0,0 +1,38 @@
+#include "test/test.h"
+#include "utils.h"
+
+namespace intgemm {
+namespace {
+
+TEST_CASE("Factorial",) {
+ CHECK(factorial(0) == 1);
+ CHECK(factorial(1) == 1);
+ CHECK(factorial(2) == 2);
+ CHECK(factorial(3) == 6);
+ CHECK(factorial(4) == 24);
+
+ // Maximum result that fits in unsinged long long
+ CHECK(factorial(20) == 2432902008176640000);
+}
+
+TEST_CASE("Expi (negative)",) {
+ const double eps = 0.0000001;
+ CHECK_EPS(expi(-1), 0.3678794411714423, eps);
+ CHECK_EPS(expi(-2), 0.1353352832366127, eps);
+ CHECK_EPS(expi(-10), 0.0000453999297625, eps);
+}
+
+TEST_CASE("Expi (zero)",) {
+ const double eps = 0.0000001;
+ CHECK_EPS(expi(0), 1.0, eps);
+}
+
+TEST_CASE("Expi (positive)",) {
+ const double eps = 0.0000001;
+ CHECK_EPS(expi(1), 2.7182818284590452, eps);
+ CHECK_EPS(expi(2), 7.3890560989306502, eps);
+ CHECK_EPS(expi(10), 22026.4657948067165170, eps);
+}
+
+}
+}