Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt1
-rw-r--r--intrinsics.h19
-rw-r--r--kernels/implementations.inl45
-rw-r--r--test/kernels/relu_test.cc50
4 files changed, 115 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1c6d96e..93018c4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -43,6 +43,7 @@ add_executable(tests
test/kernels/exp_test.cc
test/kernels/floor_ff_test.cc
test/kernels/quantize_test.cc
+ test/kernels/relu_test.cc
test/kernels/unquantize_test.cc
test/kernels/write_test.cc
diff --git a/intrinsics.h b/intrinsics.h
index 293efc3..2958aaa 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -18,6 +18,7 @@ namespace intgemm {
template <class Register> static inline Register loadu_ps(const float* mem_addr);
template <class Register> static inline Register set1_epi16(int16_t to);
template <class Register> static inline Register set1_epi32(int32_t to);
+template <class Register> static inline Register set1_pd(double to);
template <class Register> static inline Register set1_ps(float to);
template <class Register> static inline Register setzero_ps();
template <class Register> static inline Register setzero_si();
@@ -72,6 +73,9 @@ INTGEMM_SSSE3 static inline __m128i maddubs_epi16(__m128i first, __m128i second)
INTGEMM_SSE2 static inline __m128i max_epi16(__m128i first, __m128i second) {
return _mm_max_epi16(first, second);
}
+INTGEMM_SSE2 static inline __m128d max_pd(__m128d first, __m128d second) {
+ return _mm_max_pd(first, second);
+}
INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) {
return _mm_max_ps(first, second);
}
@@ -87,6 +91,9 @@ template <> INTGEMM_SSE2 inline __m128i set1_epi16<__m128i>(int16_t to) {
template <> INTGEMM_SSE2 inline __m128i set1_epi32<__m128i>(int32_t to) {
return _mm_set1_epi32(to);
}
+template <> INTGEMM_SSE2 inline __m128d set1_pd<__m128d>(double to) {
+ return _mm_set1_pd(to);
+}
template <> INTGEMM_SSE2 inline __m128 set1_ps<__m128>(float to) {
return _mm_set1_ps(to);
}
@@ -157,6 +164,9 @@ INTGEMM_AVX2 static inline __m256i max_epi8(__m256i first, __m256i second) {
INTGEMM_AVX2 static inline __m256i max_epi16(__m256i first, __m256i second) {
return _mm256_max_epi16(first, second);
}
+INTGEMM_AVX2 static inline __m256d max_pd(__m256d first, __m256d second) {
+ return _mm256_max_pd(first, second);
+}
INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) {
return _mm256_max_ps(first, second);
}
@@ -172,6 +182,9 @@ template <> INTGEMM_AVX2 inline __m256i set1_epi16<__m256i>(int16_t to) {
template <> INTGEMM_AVX2 inline __m256i set1_epi32<__m256i>(int32_t to) {
return _mm256_set1_epi32(to);
}
+template <> INTGEMM_AVX2 inline __m256d set1_pd<__m256d>(double to) {
+ return _mm256_set1_pd(to);
+}
template <> INTGEMM_AVX2 inline __m256 set1_ps<__m256>(float to) {
return _mm256_set1_ps(to);
}
@@ -244,6 +257,9 @@ INTGEMM_AVX512BW static inline __m512i max_epi8(__m512i first, __m512i second) {
INTGEMM_AVX512BW static inline __m512i max_epi16(__m512i first, __m512i second) {
return _mm512_max_epi16(first, second);
}
+INTGEMM_AVX512BW static inline __m512d max_pd(__m512d first, __m512d second) {
+ return _mm512_max_pd(first, second);
+}
INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) {
return _mm512_max_ps(first, second);
}
@@ -259,6 +275,9 @@ template <> inline INTGEMM_AVX512BW __m512i set1_epi16<__m512i>(int16_t to) {
template <> inline INTGEMM_AVX512BW __m512i set1_epi32<__m512i>(int32_t to) {
return _mm512_set1_epi32(to);
}
+template <> inline INTGEMM_AVX512BW __m512d set1_pd<__m512d>(double to) {
+ return _mm512_set1_pd(to);
+}
template <> inline INTGEMM_AVX512BW __m512 set1_ps<__m512>(float to) {
return _mm512_set1_ps(to);
}
diff --git a/kernels/implementations.inl b/kernels/implementations.inl
index 6617cdd..9c570ad 100644
--- a/kernels/implementations.inl
+++ b/kernels/implementations.inl
@@ -105,6 +105,51 @@ CPU_ATTR static inline dvf add_bias(dvf input, const float* bias_addr, Index bia
}
/*
+ * ReLU
+ */
+CPU_ATTR static inline vi relu(vi input) {
+ static const auto vconst_zero = set1_epi32<vi>(0);
+#if defined(THIS_IS_SSE2)
+ return _mm_and_si128(input, _mm_cmplt_epi32(vconst_zero, input));
+#elif defined(THIS_IS_AVX2)
+ return _mm256_max_epi32(input, vconst_zero);
+#else
+ return _mm512_max_epi32(input, vconst_zero);
+#endif
+}
+
+CPU_ATTR static inline dvi relu(dvi input) {
+ return {
+ relu(input.first),
+ relu(input.second),
+ };
+}
+
+CPU_ATTR static inline vf relu(vf input) {
+ static const auto vconst_zero = set1_ps<vf>(0);
+ return max_ps(input, vconst_zero);
+}
+
+CPU_ATTR static inline dvf relu(dvf input) {
+ return {
+ relu(input.first),
+ relu(input.second),
+ };
+}
+
+CPU_ATTR static inline vd relu(vd input) {
+ static const auto vconst_zero = set1_pd<vd>(0);
+ return max_pd(input, vconst_zero);
+}
+
+CPU_ATTR static inline dvd relu(dvd input) {
+ return {
+ relu(input.first),
+ relu(input.second),
+ };
+}
+
+/*
* Calculate floor: float -> float
*/
CPU_ATTR static inline vf floor_ff(vf a) {
diff --git a/test/kernels/relu_test.cc b/test/kernels/relu_test.cc
new file mode 100644
index 0000000..e5d919e
--- /dev/null
+++ b/test/kernels/relu_test.cc
@@ -0,0 +1,50 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "kernels.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+template <CPUType CPUType_, typename ElemType_>
+void kernel_relu_test() {
+ if (kCPU < CPUType_)
+ return;
+
+ using vec_t = vector_t<CPUType_, ElemType_>;
+ constexpr static auto VECTOR_LENGTH = sizeof(vec_t) / sizeof(ElemType_);
+
+ AlignedVector<ElemType_> input(VECTOR_LENGTH);
+ AlignedVector<ElemType_> output(VECTOR_LENGTH);
+
+ std::iota(input.begin(), input.end(), -int(VECTOR_LENGTH / 2));
+
+ *output.template as<vec_t>() = kernels::relu(*input.template as<vec_t>());
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (input[i] < 0 ? 0 : input[i]));
+}
+
+template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int>();
+template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, float>();
+template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, double>();
+TEST_CASE("Kernel: relu/int SSE2",) { return kernel_relu_test<CPUType::SSE2, int>(); }
+TEST_CASE("Kernel: relu/float SSE2",) { return kernel_relu_test<CPUType::SSE2, float>(); }
+TEST_CASE("Kernel: relu/double SSE2",) { return kernel_relu_test<CPUType::SSE2, double>(); }
+
+template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int>();
+template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, float>();
+template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, double>();
+TEST_CASE("Kernel: relu/int AVX2",) { return kernel_relu_test<CPUType::AVX2, int>(); }
+TEST_CASE("Kernel: relu/float AVX2",) { return kernel_relu_test<CPUType::AVX2, float>(); }
+TEST_CASE("Kernel: relu/double AVX2",) { return kernel_relu_test<CPUType::AVX2, double>(); }
+
+#ifndef INTGEMM_NO_AVX512
+template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int>();
+template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, float>();
+template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, double>();
+TEST_CASE("Kernel: relu/int AVX512BW",) { return kernel_relu_test<CPUType::AVX512BW, int>(); }
+TEST_CASE("Kernel: relu/float AVX512BW",) { return kernel_relu_test<CPUType::AVX512BW, float>(); }
+TEST_CASE("Kernel: relu/double AVX512BW",) { return kernel_relu_test<CPUType::AVX512BW, double>(); }
+#endif
+
+}