diff options
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | intrinsics.h | 19 | ||||
-rw-r--r-- | kernels/implementations.inl | 45 | ||||
-rw-r--r-- | test/kernels/relu_test.cc | 50 |
4 files changed, 115 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c6d96e..93018c4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,7 @@ add_executable(tests test/kernels/exp_test.cc test/kernels/floor_ff_test.cc test/kernels/quantize_test.cc + test/kernels/relu_test.cc test/kernels/unquantize_test.cc test/kernels/write_test.cc diff --git a/intrinsics.h b/intrinsics.h index 293efc3..2958aaa 100644 --- a/intrinsics.h +++ b/intrinsics.h @@ -18,6 +18,7 @@ namespace intgemm { template <class Register> static inline Register loadu_ps(const float* mem_addr); template <class Register> static inline Register set1_epi16(int16_t to); template <class Register> static inline Register set1_epi32(int32_t to); +template <class Register> static inline Register set1_pd(double to); template <class Register> static inline Register set1_ps(float to); template <class Register> static inline Register setzero_ps(); template <class Register> static inline Register setzero_si(); @@ -72,6 +73,9 @@ INTGEMM_SSSE3 static inline __m128i maddubs_epi16(__m128i first, __m128i second) INTGEMM_SSE2 static inline __m128i max_epi16(__m128i first, __m128i second) { return _mm_max_epi16(first, second); } +INTGEMM_SSE2 static inline __m128d max_pd(__m128d first, __m128d second) { + return _mm_max_pd(first, second); +} INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) { return _mm_max_ps(first, second); } @@ -87,6 +91,9 @@ template <> INTGEMM_SSE2 inline __m128i set1_epi16<__m128i>(int16_t to) { template <> INTGEMM_SSE2 inline __m128i set1_epi32<__m128i>(int32_t to) { return _mm_set1_epi32(to); } +template <> INTGEMM_SSE2 inline __m128d set1_pd<__m128d>(double to) { + return _mm_set1_pd(to); +} template <> INTGEMM_SSE2 inline __m128 set1_ps<__m128>(float to) { return _mm_set1_ps(to); } @@ -157,6 +164,9 @@ INTGEMM_AVX2 static inline __m256i max_epi8(__m256i first, __m256i second) { INTGEMM_AVX2 static inline __m256i max_epi16(__m256i first, __m256i second) { return _mm256_max_epi16(first, second); } +INTGEMM_AVX2 static inline __m256d max_pd(__m256d first, __m256d second) { + return _mm256_max_pd(first, second); +} INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) { return _mm256_max_ps(first, second); } @@ -172,6 +182,9 @@ template <> INTGEMM_AVX2 inline __m256i set1_epi16<__m256i>(int16_t to) { template <> INTGEMM_AVX2 inline __m256i set1_epi32<__m256i>(int32_t to) { return _mm256_set1_epi32(to); } +template <> INTGEMM_AVX2 inline __m256d set1_pd<__m256d>(double to) { + return _mm256_set1_pd(to); +} template <> INTGEMM_AVX2 inline __m256 set1_ps<__m256>(float to) { return _mm256_set1_ps(to); } @@ -244,6 +257,9 @@ INTGEMM_AVX512BW static inline __m512i max_epi8(__m512i first, __m512i second) { INTGEMM_AVX512BW static inline __m512i max_epi16(__m512i first, __m512i second) { return _mm512_max_epi16(first, second); } +INTGEMM_AVX512BW static inline __m512d max_pd(__m512d first, __m512d second) { + return _mm512_max_pd(first, second); +} INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) { return _mm512_max_ps(first, second); } @@ -259,6 +275,9 @@ template <> inline INTGEMM_AVX512BW __m512i set1_epi16<__m512i>(int16_t to) { template <> inline INTGEMM_AVX512BW __m512i set1_epi32<__m512i>(int32_t to) { return _mm512_set1_epi32(to); } +template <> inline INTGEMM_AVX512BW __m512d set1_pd<__m512d>(double to) { + return _mm512_set1_pd(to); +} template <> inline INTGEMM_AVX512BW __m512 set1_ps<__m512>(float to) { return _mm512_set1_ps(to); } diff --git a/kernels/implementations.inl b/kernels/implementations.inl index 6617cdd..9c570ad 100644 --- a/kernels/implementations.inl +++ b/kernels/implementations.inl @@ -105,6 +105,51 @@ CPU_ATTR static inline dvf add_bias(dvf input, const float* bias_addr, Index bia } /* + * ReLU + */ +CPU_ATTR static inline vi relu(vi input) { + static const auto vconst_zero = set1_epi32<vi>(0); +#if defined(THIS_IS_SSE2) + return _mm_and_si128(input, _mm_cmplt_epi32(vconst_zero, input)); +#elif defined(THIS_IS_AVX2) + return _mm256_max_epi32(input, vconst_zero); +#else + return _mm512_max_epi32(input, vconst_zero); +#endif +} + +CPU_ATTR static inline dvi relu(dvi input) { + return { + relu(input.first), + relu(input.second), + }; +} + +CPU_ATTR static inline vf relu(vf input) { + static const auto vconst_zero = set1_ps<vf>(0); + return max_ps(input, vconst_zero); +} + +CPU_ATTR static inline dvf relu(dvf input) { + return { + relu(input.first), + relu(input.second), + }; +} + +CPU_ATTR static inline vd relu(vd input) { + static const auto vconst_zero = set1_pd<vd>(0); + return max_pd(input, vconst_zero); +} + +CPU_ATTR static inline dvd relu(dvd input) { + return { + relu(input.first), + relu(input.second), + }; +} + +/* * Calculate floor: float -> float */ CPU_ATTR static inline vf floor_ff(vf a) { diff --git a/test/kernels/relu_test.cc b/test/kernels/relu_test.cc new file mode 100644 index 0000000..e5d919e --- /dev/null +++ b/test/kernels/relu_test.cc @@ -0,0 +1,50 @@ +#include "test/test.h" +#include "aligned.h" +#include "kernels.h" + +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_, typename ElemType_> +void kernel_relu_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, ElemType_>; + constexpr static auto VECTOR_LENGTH = sizeof(vec_t) / sizeof(ElemType_); + + AlignedVector<ElemType_> input(VECTOR_LENGTH); + AlignedVector<ElemType_> output(VECTOR_LENGTH); + + std::iota(input.begin(), input.end(), -int(VECTOR_LENGTH / 2)); + + *output.template as<vec_t>() = kernels::relu(*input.template as<vec_t>()); + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == (input[i] < 0 ? 0 : input[i])); +} + +template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int>(); +template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, float>(); +template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, double>(); +TEST_CASE("Kernel: relu/int SSE2",) { return kernel_relu_test<CPUType::SSE2, int>(); } +TEST_CASE("Kernel: relu/float SSE2",) { return kernel_relu_test<CPUType::SSE2, float>(); } +TEST_CASE("Kernel: relu/double SSE2",) { return kernel_relu_test<CPUType::SSE2, double>(); } + +template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int>(); +template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, float>(); +template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, double>(); +TEST_CASE("Kernel: relu/int AVX2",) { return kernel_relu_test<CPUType::AVX2, int>(); } +TEST_CASE("Kernel: relu/float AVX2",) { return kernel_relu_test<CPUType::AVX2, float>(); } +TEST_CASE("Kernel: relu/double AVX2",) { return kernel_relu_test<CPUType::AVX2, double>(); } + +#ifndef INTGEMM_NO_AVX512 +template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int>(); +template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, float>(); +template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, double>(); +TEST_CASE("Kernel: relu/int AVX512BW",) { return kernel_relu_test<CPUType::AVX512BW, int>(); } +TEST_CASE("Kernel: relu/float AVX512BW",) { return kernel_relu_test<CPUType::AVX512BW, float>(); } +TEST_CASE("Kernel: relu/double AVX512BW",) { return kernel_relu_test<CPUType::AVX512BW, double>(); } +#endif + +} |