Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-03 15:55:19 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-03 16:41:11 +0300
commit639cc9540e8278420a16f7ee298161b18354cb1b (patch)
tree5c3922a7dd2340cc470b2d2a303fbc4953b23da5
parentabd450f4f381932d0aa56bc652b498963c3853d3 (diff)
Add ReLU postprocessing for int16 (SSE2, AVX2, AVX512)
-rw-r--r--intrinsics.h9
-rw-r--r--postprocess.h58
-rw-r--r--test/postprocess/relu_test.cc63
3 files changed, 130 insertions, 0 deletions
diff --git a/intrinsics.h b/intrinsics.h
index 81b5f5e..293efc3 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -69,6 +69,9 @@ INTGEMM_SSSE3 static inline __m128i maddubs_epi16(__m128i first, __m128i second)
/*
* Missing max_epi8 for SSE2
*/
+INTGEMM_SSE2 static inline __m128i max_epi16(__m128i first, __m128i second) {
+ return _mm_max_epi16(first, second);
+}
INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) {
return _mm_max_ps(first, second);
}
@@ -151,6 +154,9 @@ INTGEMM_AVX2 static inline __m256i maddubs_epi16(__m256i first, __m256i second)
INTGEMM_AVX2 static inline __m256i max_epi8(__m256i first, __m256i second) {
return _mm256_max_epi8(first, second);
}
+INTGEMM_AVX2 static inline __m256i max_epi16(__m256i first, __m256i second) {
+ return _mm256_max_epi16(first, second);
+}
INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) {
return _mm256_max_ps(first, second);
}
@@ -235,6 +241,9 @@ INTGEMM_AVX512BW static inline __m512i maddubs_epi16(__m512i first, __m512i seco
INTGEMM_AVX512BW static inline __m512i max_epi8(__m512i first, __m512i second) {
return _mm512_max_epi8(first, second);
}
+INTGEMM_AVX512BW static inline __m512i max_epi16(__m512i first, __m512i second) {
+ return _mm512_max_epi16(first, second);
+}
INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) {
return _mm512_max_ps(first, second);
}
diff --git a/postprocess.h b/postprocess.h
index 44d4cfe..9acc008 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -5,6 +5,10 @@
#include "types.h"
#include "vec_utils.h"
+// TODO: We support some postprocess in few variations e.g. we support ReLU for
+// float -> float, int8 -> int8, int16 -> int16. Maybe it would be a good idea
+// to pass input type and output type as a template parameter of postprocess?
+
namespace intgemm {
/*
@@ -260,6 +264,60 @@ public:
#endif
/*
+ * ReLU_int16
+ */
+class ReLU_int16 {};
+
+template <>
+class PostprocessImpl<ReLU_int16, CPUType::SSE2> {
+public:
+ using InputRegister = RegisterPair128i;
+ using OutputRegister = RegisterPair128i;
+
+ PostprocessImpl(const ReLU_int16& config) {}
+
+ INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m128i>();
+ return {
+ max_epi16(const_zero, input.pack0123),
+ max_epi16(const_zero, input.pack4567),
+ };
+ }
+};
+
+template <>
+class PostprocessImpl<ReLU_int16, CPUType::AVX2> {
+public:
+ using InputRegister = __m256i;
+ using OutputRegister = __m256i;
+
+ PostprocessImpl(const ReLU_int16& config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m256i>();
+ return max_epi16(const_zero, input);
+ }
+};
+
+#ifndef INTGEMM_NO_AVX512
+
+template <>
+class PostprocessImpl<ReLU_int16, CPUType::AVX512BW> {
+public:
+ using InputRegister = __m512i;
+ using OutputRegister = __m512i;
+
+ PostprocessImpl(const ReLU_int16& config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m512i>();
+ return max_epi16(const_zero, input);
+ }
+};
+
+#endif
+
+/*
* Sigmoid (uses Taylor series approximation of e^x)
*/
class Sigmoid {};
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
index eec97bf..43add73 100644
--- a/test/postprocess/relu_test.cc
+++ b/test/postprocess/relu_test.cc
@@ -151,4 +151,67 @@ INTGEMM_AVX512BW TEST_CASE("ReLU_int8 AVX512",) {
#endif
+/*
+ * ReLU: int16 -> int16
+ */
+INTGEMM_SSE2 TEST_CASE("ReLU_int16 SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 10;
+
+ AlignedVector<int16_t> input(16);
+ AlignedVector<int16_t> output(16);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int16, CPUType::SSE2>(ReLU_int16());
+ auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
+ output.as<__m128i>()[0] = output_tmp.pack0123;
+ output.as<__m128i>()[1] = output_tmp.pack4567;
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+INTGEMM_AVX2 TEST_CASE("ReLU_int16 AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 10;
+
+ AlignedVector<int16_t> input(16);
+ AlignedVector<int16_t> output(16);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX2>(ReLU_int16());
+ *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("ReLU_int16 AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 15;
+
+ AlignedVector<int16_t> input(32);
+ AlignedVector<int16_t> output(32);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX512BW>(ReLU_int16());
+ *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#endif
+
}