Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <kpu@users.noreply.github.com>2019-07-04 16:55:57 +0300
committerGitHub <noreply@github.com>2019-07-04 16:55:57 +0300
commitce292be1138ecce0ec127ed59fe79d0091be7d11 (patch)
treee9fa4468af2167f4687d5b1552b62904d929c682
parent651750defccf73c2917a062ae976ebfdd34c92e9 (diff)
parent17adc99178498a826886df29974ed7cf36ab76ea (diff)
Merge pull request #20 from kpu/add-relu-int
Add relu int
-rw-r--r--intrinsics.h18
-rw-r--r--postprocess.h110
-rw-r--r--test/postprocess/relu_test.cc125
3 files changed, 253 insertions, 0 deletions
diff --git a/intrinsics.h b/intrinsics.h
index c27ca97..293efc3 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -66,6 +66,12 @@ INTGEMM_SSE2 static inline __m128i madd_epi16(__m128i first, __m128i second) {
INTGEMM_SSSE3 static inline __m128i maddubs_epi16(__m128i first, __m128i second) {
return _mm_maddubs_epi16(first, second);
}
+/*
+ * Missing max_epi8 for SSE2
+ */
+INTGEMM_SSE2 static inline __m128i max_epi16(__m128i first, __m128i second) {
+ return _mm_max_epi16(first, second);
+}
INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) {
return _mm_max_ps(first, second);
}
@@ -145,6 +151,12 @@ INTGEMM_AVX2 static inline __m256i madd_epi16(__m256i first, __m256i second) {
INTGEMM_AVX2 static inline __m256i maddubs_epi16(__m256i first, __m256i second) {
return _mm256_maddubs_epi16(first, second);
}
+INTGEMM_AVX2 static inline __m256i max_epi8(__m256i first, __m256i second) {
+ return _mm256_max_epi8(first, second);
+}
+INTGEMM_AVX2 static inline __m256i max_epi16(__m256i first, __m256i second) {
+ return _mm256_max_epi16(first, second);
+}
INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) {
return _mm256_max_ps(first, second);
}
@@ -226,6 +238,12 @@ INTGEMM_AVX512BW static inline __m512i madd_epi16(__m512i first, __m512i second)
INTGEMM_AVX512BW static inline __m512i maddubs_epi16(__m512i first, __m512i second) {
return _mm512_maddubs_epi16(first, second);
}
+INTGEMM_AVX512BW static inline __m512i max_epi8(__m512i first, __m512i second) {
+ return _mm512_max_epi8(first, second);
+}
+INTGEMM_AVX512BW static inline __m512i max_epi16(__m512i first, __m512i second) {
+ return _mm512_max_epi16(first, second);
+}
INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) {
return _mm512_max_ps(first, second);
}
diff --git a/postprocess.h b/postprocess.h
index c4dd7ae..7835b2b 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -5,6 +5,10 @@
#include "types.h"
#include "vec_utils.h"
+// TODO: We support some postprocess in few variations e.g. we support ReLU for
+// float -> float, int8 -> int8, int16 -> int16. Maybe it would be a good idea
+// to pass input type and output type as a template parameter of postprocess?
+
namespace intgemm {
/*
@@ -203,6 +207,110 @@ public:
}
};
+#endif
+
+/*
+ * ReLU_int8
+ */
+class ReLU_int8 {};
+
+template <>
+class PostprocessImpl<ReLU_int8, CPUType::SSE2> {
+public:
+ using InputRegister = __m128i;
+ using OutputRegister = __m128i;
+
+ PostprocessImpl(const ReLU_int8& config) {}
+
+ INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m128i>();
+ return _mm_and_si128(_mm_cmplt_epi8(const_zero, input), input);
+ }
+};
+
+template <>
+class PostprocessImpl<ReLU_int8, CPUType::AVX2> {
+public:
+ using InputRegister = __m256i;
+ using OutputRegister = __m256i;
+
+ PostprocessImpl(const ReLU_int8& config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m256i>();
+ return max_epi8(const_zero, input);
+ }
+};
+
+#ifndef INTGEMM_NO_AVX512
+
+template <>
+class PostprocessImpl<ReLU_int8, CPUType::AVX512BW> {
+public:
+ using InputRegister = __m512i;
+ using OutputRegister = __m512i;
+
+ PostprocessImpl(const ReLU_int8& config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m512i>();
+ return max_epi8(const_zero, input);
+ }
+};
+
+#endif
+
+/*
+ * ReLU_int16
+ */
+class ReLU_int16 {};
+
+template <>
+class PostprocessImpl<ReLU_int16, CPUType::SSE2> {
+public:
+ using InputRegister = __m128i;
+ using OutputRegister = __m128i;
+
+ PostprocessImpl(const ReLU_int16& config) {}
+
+ INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m128i>();
+ return max_epi16(const_zero, input);
+ }
+};
+
+template <>
+class PostprocessImpl<ReLU_int16, CPUType::AVX2> {
+public:
+ using InputRegister = __m256i;
+ using OutputRegister = __m256i;
+
+ PostprocessImpl(const ReLU_int16& config) {}
+
+ INTGEMM_AVX2 inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m256i>();
+ return max_epi16(const_zero, input);
+ }
+};
+
+#ifndef INTGEMM_NO_AVX512
+
+template <>
+class PostprocessImpl<ReLU_int16, CPUType::AVX512BW> {
+public:
+ using InputRegister = __m512i;
+ using OutputRegister = __m512i;
+
+ PostprocessImpl(const ReLU_int16& config) {}
+
+ INTGEMM_AVX512BW inline OutputRegister run(InputRegister input, Index offset) {
+ static const auto const_zero = setzero_si<__m512i>();
+ return max_epi16(const_zero, input);
+ }
+};
+
+#endif
+
/*
* Sigmoid (uses Taylor series approximation of e^x)
*/
@@ -256,6 +364,8 @@ public:
}
};
+#ifndef INTGEMM_NO_AVX512
+
template <>
class PostprocessImpl<Tanh, CPUType::AVX512BW> {
public:
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
index e2f2d11..af6677e 100644
--- a/test/postprocess/relu_test.cc
+++ b/test/postprocess/relu_test.cc
@@ -6,6 +6,9 @@
namespace intgemm {
+/*
+ * ReLU: float -> float
+ */
INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
if (kCPU < CPUType::SSE2)
return;
@@ -85,4 +88,126 @@ INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) {
#endif
+/*
+ * ReLU: int8 -> int8
+ */
+INTGEMM_SSE2 TEST_CASE("ReLU_int8 SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 10;
+
+ AlignedVector<int8_t> input(16);
+ AlignedVector<int8_t> output(16);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int8, CPUType::SSE2>(ReLU_int8());
+ *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+INTGEMM_AVX2 TEST_CASE("ReLU_int8 AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 10;
+
+ AlignedVector<int8_t> input(32);
+ AlignedVector<int8_t> output(32);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX2>(ReLU_int8());
+ *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("ReLU_int8 AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 30;
+
+ AlignedVector<int8_t> input(64);
+ AlignedVector<int8_t> output(64);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int8, CPUType::AVX512BW>(ReLU_int8());
+ *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#endif
+
+/*
+ * ReLU: int16 -> int16
+ */
+INTGEMM_SSE2 TEST_CASE("ReLU_int16 SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 5;
+
+ AlignedVector<int16_t> input(8);
+ AlignedVector<int16_t> output(8);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int16, CPUType::SSE2>(ReLU_int16());
+ *output.as<__m128i>() = postproc.run(*input.as<__m128i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+INTGEMM_AVX2 TEST_CASE("ReLU_int16 AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 10;
+
+ AlignedVector<int16_t> input(16);
+ AlignedVector<int16_t> output(16);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX2>(ReLU_int16());
+ *output.as<__m256i>() = postproc.run(*input.as<__m256i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("ReLU_int16 AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ const unsigned NEGATIVE_NUMBERS = 15;
+
+ AlignedVector<int16_t> input(32);
+ AlignedVector<int16_t> output(32);
+
+ std::iota(input.begin(), input.end(), -NEGATIVE_NUMBERS);
+
+ auto postproc = PostprocessImpl<ReLU_int16, CPUType::AVX512BW>(ReLU_int16());
+ *output.as<__m512i>() = postproc.run(*input.as<__m512i>(), 0);
+
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (i <= NEGATIVE_NUMBERS ? 0 : i - NEGATIVE_NUMBERS));
+}
+
+#endif
+
}