Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-03 15:45:58 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-03 16:41:11 +0300
commitabd450f4f381932d0aa56bc652b498963c3853d3 (patch)
tree2ae500ea89e6133bbb8ab7fc2875792d00253fdf
parent135f69573705df7ebf0935ac1bf86fed1ed7e685 (diff)
Add ReLU postprocessing for int8 (SSE2, AVX2, AVX512)
-rw-r--r--intrinsics.h9
-rw-r--r--postprocess.h54
-rw-r--r--test/postprocess/relu_test.cc66
3 files changed, 129 insertions, 0 deletions
diff --git a/intrinsics.h b/intrinsics.h
index c27ca97..81b5f5e 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -66,6 +66,9 @@ INTGEMM_SSE2 static inline __m128i madd_epi16(__m128i first, __m128i second) {
INTGEMM_SSSE3 static inline __m128i maddubs_epi16(__m128i first, __m128i second) {
return _mm_maddubs_epi16(first, second);
}
+/*
+ * Missing max_epi8 for SSE2
+ */
INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) {
return _mm_max_ps(first, second);
}
@@ -145,6 +148,9 @@ INTGEMM_AVX2 static inline __m256i madd_epi16(__m256i first, __m256i second) {
INTGEMM_AVX2 static inline __m256i maddubs_epi16(__m256i first, __m256i second) {
return _mm256_maddubs_epi16(first, second);
}
+INTGEMM_AVX2 static inline __m256i max_epi8(__m256i first, __m256i second) {
+ return _mm256_max_epi8(first, second);
+}
INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) {
return _mm256_max_ps(first, second);
}
@@ -226,6 +232,9 @@ INTGEMM_AVX512BW static inline __m512i madd_epi16(__m512i first, __m512i second)
INTGEMM_AVX512BW static inline __m512i maddubs_epi16(__m512i first, __m512i second) {
return _mm512_maddubs_epi16(first, second);
}
+INTGEMM_AVX512BW static inline __m512i max_epi8(__m512i first, __m512i second) {
+ return _mm512_max_epi8(first, second);
+}
INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) {
return _mm512_max_ps(first, second);
}
diff --git a/postprocess.h b/postprocess.h
index 40ce33c..44d4cfe 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -206,6 +206,60 @@ public:
#endif
/*
+ * ReLU_int8
+ */
+class ReLU_int8 {};
+
+template <>
+class PostprocessImpl<ReLU_int8, CPUType::SSE2> {
+public:
+ using InputRegister = RegisterPair128i;
+ using OutputRegister = RegisterPair128i;
+
+ PostprocessImpl(const ReLU_int8& config) {}
+
+ INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m128i>();
+ return {
+ _mm_and_si128(_mm_cmplt_epi8(const_zero, input.pack0123), input.pack0123),
+ _mm_and_si128(_mm_cmplt_epi8(const_zero, input.pack4567), input.pack4567),
+ };
+ }
+};
+
+template <>
+class PostprocessImpl<ReLU_int8, CPUType::AVX2> {
+public:
+ using InputRegister = __m256i;
+ using OutputRegister = __m256i;
+
+ PostprocessImpl(const ReLU_int8& config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m256i>();
+ return max_epi8(const_zero, input);
+ }
+};
+
+#ifndef INTGEMM_NO_AVX512
+
+template <>
+class PostprocessImpl<ReLU_int8, CPUType::AVX512BW> {
+public:
+ using InputRegister = __m512i;
+ using OutputRegister = __m512i;
+
+ PostprocessImpl(const ReLU_int8& config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m512i>();
+ return max_epi8(const_zero, input);
+ }
+};
+
+#endif
+
+/*
* Sigmoid (uses Taylor series approximation of e^x)
*/
class Sigmoid {};
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
index e2f2d11..eec97bf 100644
--- a/test/postprocess/relu_test.cc
+++ b/test/postprocess/relu_test.cc
@@ -6,6 +6,9 @@
namespace intgemm {
+/*
+ * ReLU: float -> float
+ */
INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
if (kCPU < CPUType::SSE2)
return;
@@ -85,4 +88,67 @@ INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) {
#endif
+/*
+ * ReLU: int8 -> int8
+ */
+INTGEMM_SSE2 TEST_CASE("ReLU_int8 SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 10;
+
+ AlignedVector<int8_t> input(32);
+ AlignedVector<int8_t> output(32);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int8, CPUType::SSE2>(ReLU_int8());
+ auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
+ output.as<__m128i>()[0] = output_tmp.pack0123;
+ output.as<__m128i>()[1] = output_tmp.pack4567;
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+INTGEMM_AVX2 TEST_CASE("ReLU_int8 AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 10;
+
+ AlignedVector<int8_t> input(32);
+ AlignedVector<int8_t> output(32);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX2>(ReLU_int8());
+ *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("ReLU_int8 AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 30;
+
+ AlignedVector<int8_t> input(64);
+ AlignedVector<int8_t> output(64);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX512BW>(ReLU_int8());
+ *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#endif
+
}