diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-04 16:19:19 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-04 16:19:19 +0300 |
commit | 17adc99178498a826886df29974ed7cf36ab76ea (patch) | |
tree | e9fa4468af2167f4687d5b1552b62904d929c682 | |
parent | 639cc9540e8278420a16f7ee298161b18354cb1b (diff) |
Change ReLU 8/16-bit input to a single register
-rw-r--r-- | postprocess.h | 18 | ||||
-rw-r--r-- | test/postprocess/relu_test.cc | 18 |
2 files changed, 13 insertions, 23 deletions
diff --git a/postprocess.h b/postprocess.h index 9acc008..7835b2b 100644 --- a/postprocess.h +++ b/postprocess.h @@ -217,17 +217,14 @@ class ReLU_int8 {}; template <> class PostprocessImpl<ReLU_int8, CPUType::SSE2> { public: - using InputRegister = RegisterPair128i; - using OutputRegister = RegisterPair128i; + using InputRegister = __m128i; + using OutputRegister = __m128i; PostprocessImpl(const ReLU_int8& config) {} INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { static const auto const_zero = setzero_si<__m128i>(); - return { - _mm_and_si128(_mm_cmplt_epi8(const_zero, input.pack0123), input.pack0123), - _mm_and_si128(_mm_cmplt_epi8(const_zero, input.pack4567), input.pack4567), - }; + return _mm_and_si128(_mm_cmplt_epi8(const_zero, input), input); } }; @@ -271,17 +268,14 @@ class ReLU_int16 {}; template <> class PostprocessImpl<ReLU_int16, CPUType::SSE2> { public: - using InputRegister = RegisterPair128i; - using OutputRegister = RegisterPair128i; + using InputRegister = __m128i; + using OutputRegister = __m128i; PostprocessImpl(const ReLU_int16& config) {} INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { static const auto const_zero = setzero_si<__m128i>(); - return { - max_epi16(const_zero, input.pack0123), - max_epi16(const_zero, input.pack4567), - }; + return max_epi16(const_zero, input); } }; diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc index 43add73..af6677e 100644 --- a/test/postprocess/relu_test.cc +++ b/test/postprocess/relu_test.cc @@ -97,15 +97,13 @@ INTGEMM_SSE2 TEST_CASE("ReLU_int8 SSE2",) { const unsigned NEGATIVE_NUMBERS = 10; - AlignedVector<int8_t> input(32); - AlignedVector<int8_t> output(32); + AlignedVector<int8_t> input(16); + AlignedVector<int8_t> output(16); std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); auto postproc = PostprocessImpl<ReLU_int8, CPUType::SSE2>(ReLU_int8()); - auto output_tmp = postproc.run({input.as<__m128i>()[0], 
input.as<__m128i>()[1]}, 0); - output.as<__m128i>()[0] = output_tmp.pack0123; - output.as<__m128i>()[1] = output_tmp.pack4567; + *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0); for (auto i = 0; i < output.size(); ++i) CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); @@ -158,17 +156,15 @@ INTGEMM_SSE2 TEST_CASE("ReLU_int16 SSE2",) { if (kCPU < CPUType::SSE2) return; - const unsigned NEGATIVE_NUMBERS = 10; + const unsigned NEGATIVE_NUMBERS = 5; - AlignedVector<int16_t> input(16); - AlignedVector<int16_t> output(16); + AlignedVector<int16_t> input(8); + AlignedVector<int16_t> output(8); std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); auto postproc = PostprocessImpl<ReLU_int16, CPUType::SSE2>(ReLU_int16()); - auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0); - output.as<__m128i>()[0] = output_tmp.pack0123; - output.as<__m128i>()[1] = output_tmp.pack4567; + *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0); for (auto i = 0; i < output.size(); ++i) CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); |