Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-04 16:19:19 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-04 16:19:19 +0300
commit17adc99178498a826886df29974ed7cf36ab76ea (patch)
treee9fa4468af2167f4687d5b1552b62904d929c682 /test/postprocess
parent639cc9540e8278420a16f7ee298161b18354cb1b (diff)
Change ReLU 8/16bits input to a single register
Diffstat (limited to 'test/postprocess')
-rw-r--r--test/postprocess/relu_test.cc18
1 files changed, 7 insertions, 11 deletions
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
index 43add73..af6677e 100644
--- a/test/postprocess/relu_test.cc
+++ b/test/postprocess/relu_test.cc
@@ -97,15 +97,13 @@ INTGEMM_SSE2 TEST_CASE("ReLU_int8 SSE2",) {
const unsigned NEGATIVE_NUMBERS = 10;
- AlignedVector<int8_t> input(32);
- AlignedVector<int8_t> output(32);
+ AlignedVector<int8_t> input(16);
+ AlignedVector<int8_t> output(16);
std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
auto postproc = PostprocessImpl<ReLU_int8, CPUType::SSE2>(ReLU_int8());
- auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
- output.as<__m128i>()[0] = output_tmp.pack0123;
- output.as<__m128i>()[1] = output_tmp.pack4567;
+ *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0);
for (auto i = 0; i < output.size(); ++i)
CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
@@ -158,17 +156,15 @@ INTGEMM_SSE2 TEST_CASE("ReLU_int16 SSE2",) {
if (kCPU < CPUType::SSE2)
return;
- const unsigned NEGATIVE_NUMBERS = 10;
+ const unsigned NEGATIVE_NUMBERS = 5;
- AlignedVector<int16_t> input(16);
- AlignedVector<int16_t> output(16);
+ AlignedVector<int16_t> input(8);
+ AlignedVector<int16_t> output(8);
std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
auto postproc = PostprocessImpl<ReLU_int16, CPUType::SSE2>(ReLU_int16());
- auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
- output.as<__m128i>()[0] = output_tmp.pack0123;
- output.as<__m128i>()[1] = output_tmp.pack4567;
+ *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0);
for (auto i = 0; i < output.size(); ++i)
CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));