Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Mateusz Chudyk <mateuszchudyk@gmail.com>, 2019-07-03 15:55:19 +0300
committer: Mateusz Chudyk <mateuszchudyk@gmail.com>, 2019-07-03 16:41:11 +0300
commit: 639cc9540e8278420a16f7ee298161b18354cb1b (patch)
tree: 5c3922a7dd2340cc470b2d2a303fbc4953b23da5 /test/postprocess
parent: abd450f4f381932d0aa56bc652b498963c3853d3 (diff)
Add ReLU postprocessing for int16 (SSE2, AVX2, AVX512)
Diffstat (limited to 'test/postprocess')
-rw-r--r-- test/postprocess/relu_test.cc 63
1 file changed, 63 insertions, 0 deletions
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
index eec97bf..43add73 100644
--- a/test/postprocess/relu_test.cc
+++ b/test/postprocess/relu_test.cc
@@ -151,4 +151,67 @@ INTGEMM_AVX512BW TEST_CASE("ReLU_int8 AVX512",) {
#endif
+/*
+ * ReLU: int16 -> int16
+ */
+INTGEMM_SSE2 TEST_CASE("ReLU_int16 SSE2",) {
+  // Nothing to verify when the running CPU is below SSE2.
+  if (kCPU < CPUType::SSE2)
+    return;
+
+  // Kept signed: negating an unsigned constant wraps to a huge positive value
+  // and only reaches -10 through narrowing into int16_t.
+  const int16_t NEGATIVE_NUMBERS = 10;
+  const int16_t first_value = -NEGATIVE_NUMBERS;
+
+  AlignedVector<int16_t> input(16);
+  AlignedVector<int16_t> output(16);
+
+  // input = -10, -9, ..., 5
+  std::iota(input.begin(), input.end(), first_value);
+
+  auto postproc = PostprocessImpl<ReLU_int16, CPUType::SSE2>(ReLU_int16());
+  // The 16 values are passed as two __m128i registers; the result is read back
+  // from the pack0123 / pack4567 members.
+  // NOTE(review): the trailing 0 argument is kept as in the original; its
+  // meaning is not visible from this test.
+  auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
+  output.as<__m128i>()[0] = output_tmp.pack0123;
+  output.as<__m128i>()[1] = output_tmp.pack4567;
+
+  // Every element must equal the value that was fed in, clamped below at zero.
+  int16_t fed_value = first_value;
+  for (const int16_t actual : output) {
+    const int16_t expected = fed_value > 0 ? fed_value : int16_t{0};
+    CHECK(actual == expected);
+    ++fed_value;
+  }
+}
+
+INTGEMM_AVX2 TEST_CASE("ReLU_int16 AVX2",) {
+  // Nothing to verify when the running CPU is below AVX2.
+  if (kCPU < CPUType::AVX2)
+    return;
+
+  // Kept signed: negating an unsigned constant wraps to a huge positive value
+  // and only reaches -10 through narrowing into int16_t.
+  const int16_t NEGATIVE_NUMBERS = 10;
+  const int16_t first_value = -NEGATIVE_NUMBERS;
+
+  AlignedVector<int16_t> input(16);
+  AlignedVector<int16_t> output(16);
+
+  // input = -10, -9, ..., 5
+  std::iota(input.begin(), input.end(), first_value);
+
+  auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX2>(ReLU_int16());
+  // All 16 values are processed as a single __m256i register.
+  // NOTE(review): the trailing 0 argument is kept as in the original; its
+  // meaning is not visible from this test.
+  *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0);
+
+  // Every element must equal the value that was fed in, clamped below at zero.
+  int16_t fed_value = first_value;
+  for (const int16_t actual : output) {
+    const int16_t expected = fed_value > 0 ? fed_value : int16_t{0};
+    CHECK(actual == expected);
+    ++fed_value;
+  }
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("ReLU_int16 AVX512",) {
+  // Nothing to verify when the running CPU is below AVX512BW.
+  if (kCPU < CPUType::AVX512BW)
+    return;
+
+  // Kept signed: negating an unsigned constant wraps to a huge positive value
+  // and only reaches -15 through narrowing into int16_t.
+  const int16_t NEGATIVE_NUMBERS = 15;
+  const int16_t first_value = -NEGATIVE_NUMBERS;
+
+  AlignedVector<int16_t> input(32);
+  AlignedVector<int16_t> output(32);
+
+  // input = -15, -14, ..., 16
+  std::iota(input.begin(), input.end(), first_value);
+
+  auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX512BW>(ReLU_int16());
+  // All 32 values are processed as a single __m512i register.
+  // NOTE(review): the trailing 0 argument is kept as in the original; its
+  // meaning is not visible from this test.
+  *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0);
+
+  // Every element must equal the value that was fed in, clamped below at zero.
+  int16_t fed_value = first_value;
+  for (const int16_t actual : output) {
+    const int16_t expected = fed_value > 0 ? fed_value : int16_t{0};
+    CHECK(actual == expected);
+    ++fed_value;
+  }
+}
+
+#endif
+
}