diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-22 18:20:25 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-22 18:28:00 +0300 |
commit | edabfc96e5576479e7f88b4c6bfee75c7dfda9bd (patch) | |
tree | c95d210b1e9b8402c18f9ca4a6381ad99204dd02 | |
parent | 721f4802464431dfecbc7c4bed68850f81b7af70 (diff) |
Add multiply (elemwise) kernel
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | intrinsics.h | 63 | ||||
-rw-r--r-- | kernels/implementations.inl | 41 | ||||
-rw-r--r-- | test/kernels/multiply_test.cc | 64 |
4 files changed, 169 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 2a15175..59ef89b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,6 +49,7 @@ add_executable(tests test/kernels/add_bias_test.cc test/kernels/exp_test.cc test/kernels/floor_test.cc + test/kernels/multiply_test.cc test/kernels/quantize_test.cc test/kernels/relu_test.cc test/kernels/sigmoid_test.cc diff --git a/intrinsics.h b/intrinsics.h index f6c63a9..d21565e 100644 --- a/intrinsics.h +++ b/intrinsics.h @@ -97,12 +97,21 @@ INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) { INTGEMM_SSE2 static inline __m128 min_ps(__m128 a, __m128 b) { return _mm_min_ps(a, b); } +INTGEMM_SSE2 static inline __m128i mul_epu32(__m128i a, __m128i b) { + return _mm_mul_epu32(a, b); +} INTGEMM_SSE2 static inline __m128d mul_pd(__m128d a, __m128d b) { return _mm_mul_pd(a, b); } INTGEMM_SSE2 static inline __m128 mul_ps(__m128 a, __m128 b) { return _mm_mul_ps(a, b); } +INTGEMM_SSE2 static inline __m128i mullo_epi16(__m128i a, __m128i b) { + return _mm_mullo_epi16(a, b); +} +INTGEMM_SSE2 static inline __m128i or_si(__m128i a, __m128i b) { + return _mm_or_si128(a, b); +} template <> INTGEMM_SSE2 inline __m128i set1_epi8<__m128i>(int8_t to) { return _mm_set1_epi8(to); } @@ -127,9 +136,18 @@ template <> INTGEMM_SSE2 inline __m128 setzero_ps<__m128>() { template <> INTGEMM_SSE2 inline __m128i setzero_si<__m128i>() { return _mm_setzero_si128(); } +INTGEMM_SSE2 static inline __m128i shuffle_epi32(__m128i a, int imm8) { + return _mm_shuffle_epi32(a, imm8); +} INTGEMM_SSSE3 static inline __m128i sign_epi8(__m128i first, __m128i second) { return _mm_sign_epi8(first, second); } +INTGEMM_SSE2 static inline __m128i slli_epi16(__m128i a, int8_t b) { + return _mm_slli_epi16(a, b); +} +INTGEMM_SSE2 static inline __m128i srli_epi16(__m128i a, int8_t b) { + return _mm_srli_epi16(a, b); +} INTGEMM_SSE2 static inline void storeu_ps(float* mem_addr, __m128 a) { _mm_storeu_ps(mem_addr, a); } @@ -139,6 +157,9 @@ INTGEMM_SSE2 static inline __m128d sub_pd(__m128d a, __m128d b) { INTGEMM_SSE2 static inline __m128 sub_ps(__m128 a, __m128 b) { return _mm_sub_ps(a, b); } +INTGEMM_SSE2 static inline __m128i unpacklo_epi32(__m128i a, __m128i b) { + return _mm_unpacklo_epi32(a, b); +} /* * @@ -212,12 +233,21 @@ INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) { INTGEMM_AVX2 static inline __m256 min_ps(__m256 a, __m256 b) { return _mm256_min_ps(a, b); } +INTGEMM_AVX2 static inline __m256i mul_epu32(__m256i a, __m256i b) { + return _mm256_mul_epu32(a, b); +} INTGEMM_AVX2 static inline __m256d mul_pd(__m256d a, __m256d b) { return _mm256_mul_pd(a, b); } INTGEMM_AVX2 static inline __m256 mul_ps(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); } +INTGEMM_AVX2 static inline __m256i mullo_epi16(__m256i a, __m256i b) { + return _mm256_mullo_epi16(a, b); +} +INTGEMM_AVX2 static inline __m256i or_si(__m256i a, __m256i b) { + return _mm256_or_si256(a, b); +} template <> INTGEMM_AVX2 inline __m256i set1_epi8<__m256i>(int8_t to) { return _mm256_set1_epi8(to); } @@ -242,9 +272,18 @@ template <> INTGEMM_AVX2 inline __m256 setzero_ps<__m256>() { template <> INTGEMM_AVX2 inline __m256i setzero_si<__m256i>() { return _mm256_setzero_si256(); } +INTGEMM_AVX2 static inline __m256i shuffle_epi32(__m256i a, int imm8) { + return _mm256_shuffle_epi32(a, imm8); +} INTGEMM_AVX2 static inline __m256i sign_epi8(__m256i first, __m256i second) { return _mm256_sign_epi8(first, second); } +INTGEMM_AVX2 static inline __m256i slli_epi16(__m256i a, int8_t b) { + return _mm256_slli_epi16(a, b); +} +INTGEMM_AVX2 static inline __m256i srli_epi16(__m256i a, int8_t b) { + return _mm256_srli_epi16(a, b); +} INTGEMM_AVX2 static inline void storeu_ps(float* mem_addr, __m256 a) { _mm256_storeu_ps(mem_addr, a); } @@ -254,6 +293,9 @@ INTGEMM_AVX2 static inline __m256d sub_pd(__m256d a, __m256d b) { INTGEMM_AVX2 static inline __m256 sub_ps(__m256 a, __m256 b) { return _mm256_sub_ps(a, b); } +INTGEMM_AVX2 static inline __m256i unpacklo_epi32(__m256i a, __m256i b) { + return _mm256_unpacklo_epi32(a, b); +} /* * @@ -329,12 +371,21 @@ INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) { INTGEMM_AVX512BW static inline __m512 min_ps(__m512 a, __m512 b) { return _mm512_min_ps(a, b); } +INTGEMM_AVX512BW static inline __m512i mul_epu32(__m512i a, __m512i b) { + return _mm512_mul_epu32(a, b); +} INTGEMM_AVX512BW static inline __m512d mul_pd(__m512d a, __m512d b) { return _mm512_mul_pd(a, b); } INTGEMM_AVX512BW static inline __m512 mul_ps(__m512 a, __m512 b) { return _mm512_mul_ps(a, b); } +INTGEMM_AVX512BW static inline __m512i mullo_epi16(__m512i a, __m512i b) { + return _mm512_mullo_epi16(a, b); +} +INTGEMM_AVX512BW static inline __m512i or_si(__m512i a, __m512i b) { + return _mm512_or_si512(a, b); +} template <> inline INTGEMM_AVX512BW __m512i set1_epi8<__m512i>(int8_t to) { return _mm512_set1_epi8(to); } @@ -362,6 +413,15 @@ template <> INTGEMM_AVX512BW inline __m512i setzero_si<__m512i>() { /* * Missing sign_epi8 */ +INTGEMM_AVX512BW static inline __m512i shuffle_epi32(__m512i a, _MM_PERM_ENUM imm8) { + return _mm512_shuffle_epi32(a, imm8); +} +INTGEMM_AVX512BW static inline __m512i slli_epi16(__m512i a, int8_t b) { + return _mm512_slli_epi16(a, b); +} +INTGEMM_AVX512BW static inline __m512i srli_epi16(__m512i a, int8_t b) { + return _mm512_srli_epi16(a, b); +} INTGEMM_AVX512BW static inline void storeu_ps(float* mem_addr, __m512 a) { _mm512_storeu_ps(mem_addr, a); } @@ -371,6 +431,9 @@ INTGEMM_AVX512BW static inline __m512d sub_pd(__m512d a, __m512d b) { INTGEMM_AVX512BW static inline __m512 sub_ps(__m512 a, __m512 b) { return _mm512_sub_ps(a, b); } +INTGEMM_AVX512BW static inline __m512i unpacklo_epi32(__m512i a, __m512i b) { + return _mm512_unpacklo_epi32(a, b); +} #endif diff --git a/kernels/implementations.inl b/kernels/implementations.inl index fd46390..e2565b3 100644 --- a/kernels/implementations.inl +++ b/kernels/implementations.inl @@ -142,6 +142,47 @@ CPU_ATTR inline vd relu<double>(vd input) { } /* + * Multiply (elemwise) + */ +template <typename Type> +CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> multiply(vector_t<CPUType::CPU_NAME, Type> a, vector_t<CPUType::CPU_NAME, Type> b); + +template <> +CPU_ATTR inline vi multiply<int8_t>(vi a, vi b) { + auto even = mullo_epi16(a, b); + auto odd = mullo_epi16(srli_epi16(a, 8), srli_epi16(b, 8)); + return or_si(slli_epi16(odd, 8), srli_epi16(slli_epi16(even, 8), 8)); +} + +template <> +CPU_ATTR inline vi multiply<int16_t>(vi a, vi b) { + return mullo_epi16(a, b); +} + +template <> +CPU_ATTR inline vi multiply<int>(vi a, vi b) { +#if defined(THIS_IS_SSE2) + auto even = mul_epu32(a, b); + auto odd = mul_epu32(_mm_srli_si128(a, 4), _mm_srli_si128(b, 4)); + return unpacklo_epi32(shuffle_epi32(even, 0x8 /* = 0 0 2 0 */), shuffle_epi32(odd, 0x8 /* = 0 0 2 0 */)); +#elif defined(THIS_IS_AVX2) + return _mm256_mullo_epi32(a, b); +#else + return _mm512_mullo_epi32(a, b); +#endif +} + +template <> +CPU_ATTR inline vf multiply<float>(vf a, vf b) { + return mul_ps(a, b); +} + +template <> +CPU_ATTR inline vd multiply<double>(vd a, vd b) { + return mul_pd(a, b); +} + +/* * Floor */ CPU_ATTR static inline vf floor(vf input) { diff --git a/test/kernels/multiply_test.cc b/test/kernels/multiply_test.cc new file mode 100644 index 0000000..9673e89 --- /dev/null +++ b/test/kernels/multiply_test.cc @@ -0,0 +1,64 @@ +#include "test/test.h" +#include "aligned.h" +#include "kernels.h" + +#include <numeric> + +namespace intgemm { + +template <CPUType CPUType_, typename Type_> +void kernel_multiply_test() { + if (kCPU < CPUType_) + return; + + using vec_t = vector_t<CPUType_, Type_>; + constexpr static auto VECTOR_LENGTH = sizeof(vec_t) / sizeof(Type_); + + AlignedVector<Type_> input1(VECTOR_LENGTH); + AlignedVector<Type_> input2(VECTOR_LENGTH); + AlignedVector<Type_> output(VECTOR_LENGTH); + + std::iota(input1.begin(), input1.end(), -int(VECTOR_LENGTH / 2)); + std::iota(input2.begin(), input2.end(), -int(VECTOR_LENGTH / 3)); + + *output.template as<vec_t>() = kernels::multiply<Type_>(*input1.template as<vec_t>(), *input2.template as<vec_t>()); + for (auto i = 0; i < output.size(); ++i) + CHECK(output[i] == Type_(input1[i] * input2[i])); +} + +template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int8_t>(); +template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int16_t>(); +template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int>(); +template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, float>(); +template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, double>(); +KERNEL_TEST_CASE("multiply/int8 SSE2") { return kernel_multiply_test<CPUType::SSE2, int8_t>(); } +KERNEL_TEST_CASE("multiply/int16 SSE2") { return kernel_multiply_test<CPUType::SSE2, int16_t>(); } +KERNEL_TEST_CASE("multiply/int SSE2") { return kernel_multiply_test<CPUType::SSE2, int>(); } +KERNEL_TEST_CASE("multiply/float SSE2") { return kernel_multiply_test<CPUType::SSE2, float>(); } +KERNEL_TEST_CASE("multiply/double SSE2") { return kernel_multiply_test<CPUType::SSE2, double>(); } + +template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int8_t>(); +template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int16_t>(); +template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int>(); +template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, float>(); +template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, double>(); +KERNEL_TEST_CASE("multiply/int8 AVX2") { return kernel_multiply_test<CPUType::AVX2, int8_t>(); } +KERNEL_TEST_CASE("multiply/int16 AVX2") { return kernel_multiply_test<CPUType::AVX2, int16_t>(); } +KERNEL_TEST_CASE("multiply/int AVX2") { return kernel_multiply_test<CPUType::AVX2, int>(); } +KERNEL_TEST_CASE("multiply/float AVX2") { return kernel_multiply_test<CPUType::AVX2, float>(); } +KERNEL_TEST_CASE("multiply/double AVX2") { return kernel_multiply_test<CPUType::AVX2, double>(); } + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512 +template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int8_t>(); +template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int16_t>(); +template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int>(); +template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, float>(); +template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, double>(); +KERNEL_TEST_CASE("multiply/int8 AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int8_t>(); } +KERNEL_TEST_CASE("multiply/int16 AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int16_t>(); } +KERNEL_TEST_CASE("multiply/int AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int>(); } +KERNEL_TEST_CASE("multiply/float AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, float>(); } +KERNEL_TEST_CASE("multiply/double AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, double>(); } +#endif + +} |