diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-03 15:45:58 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-03 16:41:11 +0300 |
commit | abd450f4f381932d0aa56bc652b498963c3853d3 (patch) | |
tree | 2ae500ea89e6133bbb8ab7fc2875792d00253fdf | |
parent | 135f69573705df7ebf0935ac1bf86fed1ed7e685 (diff) |
Add ReLU postprocessing for int8 (SSE2, AVX2, AVX512)
-rw-r--r-- | intrinsics.h | 9 | ||||
-rw-r--r-- | postprocess.h | 54 | ||||
-rw-r--r-- | test/postprocess/relu_test.cc | 66 |
3 files changed, 129 insertions, 0 deletions
diff --git a/intrinsics.h b/intrinsics.h index c27ca97..81b5f5e 100644 --- a/intrinsics.h +++ b/intrinsics.h @@ -66,6 +66,9 @@ INTGEMM_SSE2 static inline __m128i madd_epi16(__m128i first, __m128i second) { INTGEMM_SSSE3 static inline __m128i maddubs_epi16(__m128i first, __m128i second) { return _mm_maddubs_epi16(first, second); } +/* + * Missing max_epi8 for SSE2 + */ INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) { return _mm_max_ps(first, second); } @@ -145,6 +148,9 @@ INTGEMM_AVX2 static inline __m256i madd_epi16(__m256i first, __m256i second) { INTGEMM_AVX2 static inline __m256i maddubs_epi16(__m256i first, __m256i second) { return _mm256_maddubs_epi16(first, second); } +INTGEMM_AVX2 static inline __m256i max_epi8(__m256i first, __m256i second) { + return _mm256_max_epi8(first, second); +} INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) { return _mm256_max_ps(first, second); } @@ -226,6 +232,9 @@ INTGEMM_AVX512BW static inline __m512i madd_epi16(__m512i first, __m512i second) INTGEMM_AVX512BW static inline __m512i maddubs_epi16(__m512i first, __m512i second) { return _mm512_maddubs_epi16(first, second); } +INTGEMM_AVX512BW static inline __m512i max_epi8(__m512i first, __m512i second) { + return _mm512_max_epi8(first, second); +} INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) { return _mm512_max_ps(first, second); } diff --git a/postprocess.h b/postprocess.h index 40ce33c..44d4cfe 100644 --- a/postprocess.h +++ b/postprocess.h @@ -206,6 +206,60 @@ public: #endif /* + * ReLU_int8 + */ +class ReLU_int8 {}; + +template <> +class PostprocessImpl<ReLU_int8, CPUType::SSE2> { +public: + using InputRegister = RegisterPair128i; + using OutputRegister = RegisterPair128i; + + PostprocessImpl(const ReLU_int8& config) {} + + INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = setzero_si<__m128i>(); + return { + _mm_and_si128(_mm_cmplt_epi8(const_zero, input.pack0123), input.pack0123), + _mm_and_si128(_mm_cmplt_epi8(const_zero, input.pack4567), input.pack4567), + }; + } +}; + +template <> +class PostprocessImpl<ReLU_int8, CPUType::AVX2> { +public: + using InputRegister = __m256i; + using OutputRegister = __m256i; + + PostprocessImpl(const ReLU_int8& config) {} + + INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = setzero_si<__m256i>(); + return max_epi8(const_zero, input); + } +}; + +#ifndef INTGEMM_NO_AVX512 + +template <> +class PostprocessImpl<ReLU_int8, CPUType::AVX512BW> { +public: + using InputRegister = __m512i; + using OutputRegister = __m512i; + + PostprocessImpl(const ReLU_int8& config) {} + + INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) { + static const auto const_zero = setzero_si<__m512i>(); + return max_epi8(const_zero, input); + } +}; + +#endif + +/* * Sigmoid (uses Taylor series approximation of e^x) */ class Sigmoid {}; diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc index e2f2d11..eec97bf 100644 --- a/test/postprocess/relu_test.cc +++ b/test/postprocess/relu_test.cc @@ -6,6 +6,9 @@ namespace intgemm { +/* + * ReLU: float -> float + */ INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) { if (kCPU < CPUType::SSE2) return; @@ -85,4 +88,67 @@ INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) { #endif +/* + * ReLU: int8 -> int8 + */ +INTGEMM_SSE2 TEST_CASE("ReLU_int8 SSE2",) { + if (kCPU < CPUType::SSE2) + return; + + const unsigned NEGATIVE_NUMBERS = 10; + + AlignedVector<int8_t> input(32); + AlignedVector<int8_t> output(32); + + std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); + + auto postproc = PostprocessImpl<ReLU_int8, CPUType::SSE2>(ReLU_int8()); + auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0); + output.as<__m128i>()[0] = output_tmp.pack0123; + output.as<__m128i>()[1] = output_tmp.pack4567; + + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); +} + +INTGEMM_AVX2 TEST_CASE("ReLU_int8 AVX2",) { + if (kCPU < CPUType::AVX2) + return; + + const unsigned NEGATIVE_NUMBERS = 10; + + AlignedVector<int8_t> input(32); + AlignedVector<int8_t> output(32); + + std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); + + auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX2>(ReLU_int8()); + *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0); + + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); +} + +#ifndef INTGEMM_NO_AVX512 + +INTGEMM_AVX512BW TEST_CASE("ReLU_int8 AVX512",) { + if (kCPU < CPUType::AVX512BW) + return; + + const unsigned NEGATIVE_NUMBERS = 30; + + AlignedVector<int8_t> input(64); + AlignedVector<int8_t> output(64); + + std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS); + + auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX512BW>(ReLU_int8()); + *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0); + + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS)); +} + +#endif + } |