Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-04 16:19:19 +0300
committer: Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-04 16:19:19 +0300
commit: 17adc99178498a826886df29974ed7cf36ab76ea (patch)
tree: e9fa4468af2167f4687d5b1552b62904d929c682
parent: 639cc9540e8278420a16f7ee298161b18354cb1b (diff)
Change ReLU 8/16bits input to a single register
-rw-r--r--  postprocess.h  18
-rw-r--r--  test/postprocess/relu_test.cc  18
2 files changed, 13 insertions, 23 deletions
diff --git a/postprocess.h b/postprocess.h
index 9acc008..7835b2b 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -217,17 +217,14 @@ class ReLU_int8 {};
template <>
class PostprocessImpl<ReLU_int8, CPUType::SSE2> {
public:
- using InputRegister = RegisterPair128i;
- using OutputRegister = RegisterPair128i;
+ using InputRegister = __m128i;
+ using OutputRegister = __m128i;
PostprocessImpl(const ReLU_int8& config) {}
INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
static const auto const_zero = setzero_si<__m128i>();
- return {
- _mm_and_si128(_mm_cmplt_epi8(const_zero, input.pack0123), input.pack0123),
- _mm_and_si128(_mm_cmplt_epi8(const_zero, input.pack4567), input.pack4567),
- };
+ return _mm_and_si128(_mm_cmplt_epi8(const_zero, input), input);
}
};
@@ -271,17 +268,14 @@ class ReLU_int16 {};
template <>
class PostprocessImpl<ReLU_int16, CPUType::SSE2> {
public:
- using InputRegister = RegisterPair128i;
- using OutputRegister = RegisterPair128i;
+ using InputRegister = __m128i;
+ using OutputRegister = __m128i;
PostprocessImpl(const ReLU_int16& config) {}
INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
static const auto const_zero = setzero_si<__m128i>();
- return {
- max_epi16(const_zero, input.pack0123),
- max_epi16(const_zero, input.pack4567),
- };
+ return max_epi16(const_zero, input);
}
};
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
index 43add73..af6677e 100644
--- a/test/postprocess/relu_test.cc
+++ b/test/postprocess/relu_test.cc
@@ -97,15 +97,13 @@ INTGEMM_SSE2 TEST_CASE("ReLU_int8 SSE2",) {
const unsigned NEGATIVE_NUMBERS = 10;
- AlignedVector<int8_t> input(32);
- AlignedVector<int8_t> output(32);
+ AlignedVector<int8_t> input(16);
+ AlignedVector<int8_t> output(16);
std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
auto postproc = PostprocessImpl<ReLU_int8, CPUType::SSE2>(ReLU_int8());
- auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
- output.as<__m128i>()[0] = output_tmp.pack0123;
- output.as<__m128i>()[1] = output_tmp.pack4567;
+ *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0);
for (auto i = 0; i < output.size(); ++i)
CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
@@ -158,17 +156,15 @@ INTGEMM_SSE2 TEST_CASE("ReLU_int16 SSE2",) {
if (kCPU < CPUType::SSE2)
return;
- const unsigned NEGATIVE_NUMBERS = 10;
+ const unsigned NEGATIVE_NUMBERS = 5;
- AlignedVector<int16_t> input(16);
- AlignedVector<int16_t> output(16);
+ AlignedVector<int16_t> input(8);
+ AlignedVector<int16_t> output(8);
std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
auto postproc = PostprocessImpl<ReLU_int16, CPUType::SSE2>(ReLU_int16());
- auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
- output.as<__m128i>()[0] = output_tmp.pack0123;
- output.as<__m128i>()[1] = output_tmp.pack4567;
+ *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0);
for (auto i = 0; i < output.size(); ++i)
CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));