Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-22 18:20:25 +0300
committer Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-22 18:28:00 +0300
commit edabfc96e5576479e7f88b4c6bfee75c7dfda9bd (patch)
tree c95d210b1e9b8402c18f9ca4a6381ad99204dd02 /test/kernels/multiply_test.cc
parent 721f4802464431dfecbc7c4bed68850f81b7af70 (diff)
Add multiply (elemwise) kernel
Diffstat (limited to 'test/kernels/multiply_test.cc')
-rw-r--r-- test/kernels/multiply_test.cc 64
1 files changed, 64 insertions, 0 deletions
diff --git a/test/kernels/multiply_test.cc b/test/kernels/multiply_test.cc
new file mode 100644
index 0000000..9673e89
--- /dev/null
+++ b/test/kernels/multiply_test.cc
@@ -0,0 +1,64 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "kernels.h"
+
+#include <cstddef>
+#include <numeric>
+
+namespace intgemm {
+
+/**
+ * Checks kernels::multiply (element-wise product) for one CPU type / element
+ * type combination.
+ *
+ * Two buffers holding exactly one SIMD vector's worth of elements are filled
+ * with consecutive values, multiplied with a single kernels::multiply call,
+ * and every output element is compared against the scalar product converted
+ * back to Type_.
+ *
+ * The function returns without checking anything when kCPU < CPUType_.
+ * NOTE(review): kCPU is declared elsewhere; presumably it is the highest
+ * instruction set available on the running machine - confirm.
+ *
+ * @tparam CPUType_ instruction set whose vector type and kernel are tested.
+ * @tparam Type_    element type stored in the vectors.
+ */
+template <CPUType CPUType_, typename Type_>
+void kernel_multiply_test() {
+  if (kCPU < CPUType_)
+    return;
+
+  using vec_t = vector_t<CPUType_, Type_>;
+  constexpr static auto VECTOR_LENGTH = sizeof(vec_t) / sizeof(Type_);
+
+  AlignedVector<Type_> input1(VECTOR_LENGTH);
+  AlignedVector<Type_> input2(VECTOR_LENGTH);
+  AlignedVector<Type_> output(VECTOR_LENGTH);
+
+  // Seed std::iota with a Type_ value so the sequence is generated in the
+  // element type itself instead of being implicitly narrowed from int on
+  // every store. The ranges straddle zero, so negative, zero and positive
+  // factors are all exercised.
+  std::iota(input1.begin(), input1.end(), static_cast<Type_>(-static_cast<int>(VECTOR_LENGTH / 2)));
+  std::iota(input2.begin(), input2.end(), static_cast<Type_>(-static_cast<int>(VECTOR_LENGTH / 3)));
+
+  // Each buffer holds exactly one vec_t; as<vec_t>() yields a pointer that is
+  // dereferenced to load/store the whole buffer as a single vector.
+  *output.template as<vec_t>() = kernels::multiply<Type_>(*input1.template as<vec_t>(), *input2.template as<vec_t>());
+
+  // Unsigned index: avoids comparing a signed int against output.size().
+  for (std::size_t i = 0; i < output.size(); ++i)
+    CHECK(output[i] == Type_(input1[i] * input2[i]));
+}
+
+template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int8_t>(); // Explicit instantiations for CPUType::SSE2, one per element type.
+template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int16_t>(); // INTGEMM_SSE2 is defined outside this file; presumably a target attribute - confirm.
+template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int>();
+template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, float>();
+template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, double>();
+KERNEL_TEST_CASE("multiply/int8 SSE2") { return kernel_multiply_test<CPUType::SSE2, int8_t>(); } // One named test case per instantiation above; each body only forwards to it.
+KERNEL_TEST_CASE("multiply/int16 SSE2") { return kernel_multiply_test<CPUType::SSE2, int16_t>(); } // KERNEL_TEST_CASE is defined outside this file (presumably test/test.h - confirm).
+KERNEL_TEST_CASE("multiply/int SSE2") { return kernel_multiply_test<CPUType::SSE2, int>(); }
+KERNEL_TEST_CASE("multiply/float SSE2") { return kernel_multiply_test<CPUType::SSE2, float>(); }
+KERNEL_TEST_CASE("multiply/double SSE2") { return kernel_multiply_test<CPUType::SSE2, double>(); }
+
+template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int8_t>(); // Explicit instantiations for CPUType::AVX2, one per element type.
+template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int16_t>(); // INTGEMM_AVX2 is defined outside this file; presumably a target attribute - confirm.
+template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int>();
+template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, float>();
+template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, double>();
+KERNEL_TEST_CASE("multiply/int8 AVX2") { return kernel_multiply_test<CPUType::AVX2, int8_t>(); } // One named test case per instantiation above; each body only forwards to it.
+KERNEL_TEST_CASE("multiply/int16 AVX2") { return kernel_multiply_test<CPUType::AVX2, int16_t>(); }
+KERNEL_TEST_CASE("multiply/int AVX2") { return kernel_multiply_test<CPUType::AVX2, int>(); }
+KERNEL_TEST_CASE("multiply/float AVX2") { return kernel_multiply_test<CPUType::AVX2, float>(); }
+KERNEL_TEST_CASE("multiply/double AVX2") { return kernel_multiply_test<CPUType::AVX2, double>(); }
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512 // The AVX512BW instantiations and test cases exist only when this macro is defined.
+template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int8_t>(); // Explicit instantiations for CPUType::AVX512BW, one per element type.
+template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int16_t>(); // INTGEMM_AVX512BW is defined outside this file; presumably a target attribute - confirm.
+template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int>();
+template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, float>();
+template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, double>();
+KERNEL_TEST_CASE("multiply/int8 AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int8_t>(); } // One named test case per instantiation above; each body only forwards to it.
+KERNEL_TEST_CASE("multiply/int16 AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int16_t>(); }
+KERNEL_TEST_CASE("multiply/int AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int>(); }
+KERNEL_TEST_CASE("multiply/float AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, float>(); }
+KERNEL_TEST_CASE("multiply/double AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, double>(); }
+#endif // INTGEMM_COMPILER_SUPPORTS_AVX512
+
+}