github.com/marian-nmt/intgemm/intgemm.git
author    Mateusz Chudyk <mateuszchudyk@gmail.com>  2019-07-22 18:20:25 +0300
committer Mateusz Chudyk <mateuszchudyk@gmail.com>  2019-07-22 18:28:00 +0300
commit    edabfc96e5576479e7f88b4c6bfee75c7dfda9bd (patch)
tree      c95d210b1e9b8402c18f9ca4a6381ad99204dd02
parent    721f4802464431dfecbc7c4bed68850f81b7af70 (diff)
Add multiply (elemwise) kernel
-rw-r--r--  CMakeLists.txt                  1
-rw-r--r--  intrinsics.h                   63
-rw-r--r--  kernels/implementations.inl    41
-rw-r--r--  test/kernels/multiply_test.cc  64
4 files changed, 169 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2a15175..59ef89b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -49,6 +49,7 @@ add_executable(tests
test/kernels/add_bias_test.cc
test/kernels/exp_test.cc
test/kernels/floor_test.cc
+ test/kernels/multiply_test.cc
test/kernels/quantize_test.cc
test/kernels/relu_test.cc
test/kernels/sigmoid_test.cc
diff --git a/intrinsics.h b/intrinsics.h
index f6c63a9..d21565e 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -97,12 +97,21 @@ INTGEMM_SSE2 static inline __m128 max_ps(__m128 first, __m128 second) {
INTGEMM_SSE2 static inline __m128 min_ps(__m128 a, __m128 b) {
return _mm_min_ps(a, b);
}
+INTGEMM_SSE2 static inline __m128i mul_epu32(__m128i a, __m128i b) {
+ return _mm_mul_epu32(a, b);
+}
INTGEMM_SSE2 static inline __m128d mul_pd(__m128d a, __m128d b) {
return _mm_mul_pd(a, b);
}
INTGEMM_SSE2 static inline __m128 mul_ps(__m128 a, __m128 b) {
return _mm_mul_ps(a, b);
}
+INTGEMM_SSE2 static inline __m128i mullo_epi16(__m128i a, __m128i b) {
+ return _mm_mullo_epi16(a, b);
+}
+INTGEMM_SSE2 static inline __m128i or_si(__m128i a, __m128i b) {
+ return _mm_or_si128(a, b);
+}
template <> INTGEMM_SSE2 inline __m128i set1_epi8<__m128i>(int8_t to) {
return _mm_set1_epi8(to);
}
@@ -127,9 +136,18 @@ template <> INTGEMM_SSE2 inline __m128 setzero_ps<__m128>() {
template <> INTGEMM_SSE2 inline __m128i setzero_si<__m128i>() {
return _mm_setzero_si128();
}
+INTGEMM_SSE2 static inline __m128i shuffle_epi32(__m128i a, int imm8) {
+ return _mm_shuffle_epi32(a, imm8);
+}
INTGEMM_SSSE3 static inline __m128i sign_epi8(__m128i first, __m128i second) {
return _mm_sign_epi8(first, second);
}
+INTGEMM_SSE2 static inline __m128i slli_epi16(__m128i a, int8_t b) {
+ return _mm_slli_epi16(a, b);
+}
+INTGEMM_SSE2 static inline __m128i srli_epi16(__m128i a, int8_t b) {
+ return _mm_srli_epi16(a, b);
+}
INTGEMM_SSE2 static inline void storeu_ps(float* mem_addr, __m128 a) {
_mm_storeu_ps(mem_addr, a);
}
@@ -139,6 +157,9 @@ INTGEMM_SSE2 static inline __m128d sub_pd(__m128d a, __m128d b) {
INTGEMM_SSE2 static inline __m128 sub_ps(__m128 a, __m128 b) {
return _mm_sub_ps(a, b);
}
+INTGEMM_SSE2 static inline __m128i unpacklo_epi32(__m128i a, __m128i b) {
+ return _mm_unpacklo_epi32(a, b);
+}
/*
*
@@ -212,12 +233,21 @@ INTGEMM_AVX2 static inline __m256 max_ps(__m256 first, __m256 second) {
INTGEMM_AVX2 static inline __m256 min_ps(__m256 a, __m256 b) {
return _mm256_min_ps(a, b);
}
+INTGEMM_AVX2 static inline __m256i mul_epu32(__m256i a, __m256i b) {
+ return _mm256_mul_epu32(a, b);
+}
INTGEMM_AVX2 static inline __m256d mul_pd(__m256d a, __m256d b) {
return _mm256_mul_pd(a, b);
}
INTGEMM_AVX2 static inline __m256 mul_ps(__m256 a, __m256 b) {
return _mm256_mul_ps(a, b);
}
+INTGEMM_AVX2 static inline __m256i mullo_epi16(__m256i a, __m256i b) {
+ return _mm256_mullo_epi16(a, b);
+}
+INTGEMM_AVX2 static inline __m256i or_si(__m256i a, __m256i b) {
+ return _mm256_or_si256(a, b);
+}
template <> INTGEMM_AVX2 inline __m256i set1_epi8<__m256i>(int8_t to) {
return _mm256_set1_epi8(to);
}
@@ -242,9 +272,18 @@ template <> INTGEMM_AVX2 inline __m256 setzero_ps<__m256>() {
template <> INTGEMM_AVX2 inline __m256i setzero_si<__m256i>() {
return _mm256_setzero_si256();
}
+INTGEMM_AVX2 static inline __m256i shuffle_epi32(__m256i a, int imm8) {
+ return _mm256_shuffle_epi32(a, imm8);
+}
INTGEMM_AVX2 static inline __m256i sign_epi8(__m256i first, __m256i second) {
return _mm256_sign_epi8(first, second);
}
+INTGEMM_AVX2 static inline __m256i slli_epi16(__m256i a, int8_t b) {
+ return _mm256_slli_epi16(a, b);
+}
+INTGEMM_AVX2 static inline __m256i srli_epi16(__m256i a, int8_t b) {
+ return _mm256_srli_epi16(a, b);
+}
INTGEMM_AVX2 static inline void storeu_ps(float* mem_addr, __m256 a) {
_mm256_storeu_ps(mem_addr, a);
}
@@ -254,6 +293,9 @@ INTGEMM_AVX2 static inline __m256d sub_pd(__m256d a, __m256d b) {
INTGEMM_AVX2 static inline __m256 sub_ps(__m256 a, __m256 b) {
return _mm256_sub_ps(a, b);
}
+INTGEMM_AVX2 static inline __m256i unpacklo_epi32(__m256i a, __m256i b) {
+ return _mm256_unpacklo_epi32(a, b);
+}
/*
*
@@ -329,12 +371,21 @@ INTGEMM_AVX512BW static inline __m512 max_ps(__m512 first, __m512 second) {
INTGEMM_AVX512BW static inline __m512 min_ps(__m512 a, __m512 b) {
return _mm512_min_ps(a, b);
}
+INTGEMM_AVX512BW static inline __m512i mul_epu32(__m512i a, __m512i b) {
+ return _mm512_mul_epu32(a, b);
+}
INTGEMM_AVX512BW static inline __m512d mul_pd(__m512d a, __m512d b) {
return _mm512_mul_pd(a, b);
}
INTGEMM_AVX512BW static inline __m512 mul_ps(__m512 a, __m512 b) {
return _mm512_mul_ps(a, b);
}
+INTGEMM_AVX512BW static inline __m512i mullo_epi16(__m512i a, __m512i b) {
+ return _mm512_mullo_epi16(a, b);
+}
+INTGEMM_AVX512BW static inline __m512i or_si(__m512i a, __m512i b) {
+ return _mm512_or_si512(a, b);
+}
template <> inline INTGEMM_AVX512BW __m512i set1_epi8<__m512i>(int8_t to) {
return _mm512_set1_epi8(to);
}
@@ -362,6 +413,15 @@ template <> INTGEMM_AVX512BW inline __m512i setzero_si<__m512i>() {
/*
* Missing sign_epi8
*/
+INTGEMM_AVX512BW static inline __m512i shuffle_epi32(__m512i a, _MM_PERM_ENUM imm8) {
+ return _mm512_shuffle_epi32(a, imm8);
+}
+INTGEMM_AVX512BW static inline __m512i slli_epi16(__m512i a, int8_t b) {
+ return _mm512_slli_epi16(a, b);
+}
+INTGEMM_AVX512BW static inline __m512i srli_epi16(__m512i a, int8_t b) {
+ return _mm512_srli_epi16(a, b);
+}
INTGEMM_AVX512BW static inline void storeu_ps(float* mem_addr, __m512 a) {
_mm512_storeu_ps(mem_addr, a);
}
@@ -371,6 +431,9 @@ INTGEMM_AVX512BW static inline __m512d sub_pd(__m512d a, __m512d b) {
INTGEMM_AVX512BW static inline __m512 sub_ps(__m512 a, __m512 b) {
return _mm512_sub_ps(a, b);
}
+INTGEMM_AVX512BW static inline __m512i unpacklo_epi32(__m512i a, __m512i b) {
+ return _mm512_unpacklo_epi32(a, b);
+}
#endif
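
The wrappers added above are deliberately thin: each gives the SSE2, AVX2 and AVX512BW intrinsic of the same operation a common name and a target attribute, so the kernel bodies in kernels/implementations.inl can be written once per operation and instantiated for every register width. A minimal sketch of that overload pattern, using a hypothetical add_epi32 wrapper that is not part of this patch (build with the matching -msse2/-mavx2 flags or equivalent target attributes):

#include <immintrin.h>

static inline __m128i add_epi32(__m128i a, __m128i b) { return _mm_add_epi32(a, b); }
static inline __m256i add_epi32(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); }

// The same generic body now compiles for both register widths.
template <typename Register>
static inline Register add_three(Register a, Register b, Register c) {
  return add_epi32(add_epi32(a, b), c);
}
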
diff --git a/kernels/implementations.inl b/kernels/implementations.inl
index fd46390..e2565b3 100644
--- a/kernels/implementations.inl
+++ b/kernels/implementations.inl
@@ -142,6 +142,47 @@ CPU_ATTR inline vd relu<double>(vd input) {
}
/*
+ * Multiply (elemwise)
+ */
+template <typename Type>
+CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> multiply(vector_t<CPUType::CPU_NAME, Type> a, vector_t<CPUType::CPU_NAME, Type> b);
+
+template <>
+CPU_ATTR inline vi multiply<int8_t>(vi a, vi b) {
+ auto even = mullo_epi16(a, b);
+ auto odd = mullo_epi16(srli_epi16(a, 8), srli_epi16(b, 8));
+ return or_si(slli_epi16(odd, 8), srli_epi16(slli_epi16(even, 8), 8));
+}
+
+template <>
+CPU_ATTR inline vi multiply<int16_t>(vi a, vi b) {
+ return mullo_epi16(a, b);
+}
+
+template <>
+CPU_ATTR inline vi multiply<int>(vi a, vi b) {
+#if defined(THIS_IS_SSE2)
+ auto even = mul_epu32(a, b);
+ auto odd = mul_epu32(_mm_srli_si128(a, 4), _mm_srli_si128(b, 4));
+ return unpacklo_epi32(shuffle_epi32(even, 0x8 /* = 0 0 2 0 */), shuffle_epi32(odd, 0x8 /* = 0 0 2 0 */));
+#elif defined(THIS_IS_AVX2)
+ return _mm256_mullo_epi32(a, b);
+#else
+ return _mm512_mullo_epi32(a, b);
+#endif
+}
+
+template <>
+CPU_ATTR inline vf multiply<float>(vf a, vf b) {
+ return mul_ps(a, b);
+}
+
+template <>
+CPU_ATTR inline vd multiply<double>(vd a, vd b) {
+ return mul_pd(a, b);
+}
+
+/*
* Floor
*/
CPU_ATTR static inline vf floor(vf input) {
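
Two of the specializations above emulate multiplies that have no single instruction on the targeted ISAs. For int8_t there is no packed 8-bit multiply, so the kernel multiplies the even and odd bytes separately as 16-bit lanes and recombines the low byte of each product; for 32-bit lanes on SSE2 (no _mm_mullo_epi32 before SSE4.1) it assembles the low 32 bits of each product from two widening mul_epu32 calls plus a shuffle and unpack. A per-lane scalar model of what both paths return (illustration only, not part of the patch):

#include <cstdint>

// int8 path: each output byte is the low 8 bits of the full product,
// which is exactly what the shift/or recombination of the even and odd
// 16-bit products keeps.
inline int8_t multiply_i8_scalar(int8_t a, int8_t b) {
  return static_cast<int8_t>(static_cast<int16_t>(a) * static_cast<int16_t>(b));
}

// 32-bit SSE2 path: mul_epu32 yields 64-bit products of the even lanes,
// shifting the inputs right by one lane covers the odd lanes, and the
// shuffle/unpacklo keep only the low 32 bits of each product. The low
// 32 bits are identical for signed and unsigned inputs, so this matches
// a signed mullo as well.
inline uint32_t multiply_i32_scalar(uint32_t a, uint32_t b) {
  return static_cast<uint32_t>(static_cast<uint64_t>(a) * b);
}
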
diff --git a/test/kernels/multiply_test.cc b/test/kernels/multiply_test.cc
new file mode 100644
index 0000000..9673e89
--- /dev/null
+++ b/test/kernels/multiply_test.cc
@@ -0,0 +1,64 @@
+#include "test/test.h"
+#include "aligned.h"
+#include "kernels.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+template <CPUType CPUType_, typename Type_>
+void kernel_multiply_test() {
+ if (kCPU < CPUType_)
+ return;
+
+ using vec_t = vector_t<CPUType_, Type_>;
+ constexpr static auto VECTOR_LENGTH = sizeof(vec_t) / sizeof(Type_);
+
+ AlignedVector<Type_> input1(VECTOR_LENGTH);
+ AlignedVector<Type_> input2(VECTOR_LENGTH);
+ AlignedVector<Type_> output(VECTOR_LENGTH);
+
+ std::iota(input1.begin(), input1.end(), -int(VECTOR_LENGTH / 2));
+ std::iota(input2.begin(), input2.end(), -int(VECTOR_LENGTH / 3));
+
+ *output.template as<vec_t>() = kernels::multiply<Type_>(*input1.template as<vec_t>(), *input2.template as<vec_t>());
+ for (auto i = 0; i < output.size(); ++i)
+ CHECK(output[i] == Type_(input1[i] * input2[i]));
+}
+
+template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int8_t>();
+template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int16_t>();
+template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, int>();
+template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, float>();
+template INTGEMM_SSE2 void kernel_multiply_test<CPUType::SSE2, double>();
+KERNEL_TEST_CASE("multiply/int8 SSE2") { return kernel_multiply_test<CPUType::SSE2, int8_t>(); }
+KERNEL_TEST_CASE("multiply/int16 SSE2") { return kernel_multiply_test<CPUType::SSE2, int16_t>(); }
+KERNEL_TEST_CASE("multiply/int SSE2") { return kernel_multiply_test<CPUType::SSE2, int>(); }
+KERNEL_TEST_CASE("multiply/float SSE2") { return kernel_multiply_test<CPUType::SSE2, float>(); }
+KERNEL_TEST_CASE("multiply/double SSE2") { return kernel_multiply_test<CPUType::SSE2, double>(); }
+
+template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int8_t>();
+template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int16_t>();
+template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int>();
+template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, float>();
+template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, double>();
+KERNEL_TEST_CASE("multiply/int8 AVX2") { return kernel_multiply_test<CPUType::AVX2, int8_t>(); }
+KERNEL_TEST_CASE("multiply/int16 AVX2") { return kernel_multiply_test<CPUType::AVX2, int16_t>(); }
+KERNEL_TEST_CASE("multiply/int AVX2") { return kernel_multiply_test<CPUType::AVX2, int>(); }
+KERNEL_TEST_CASE("multiply/float AVX2") { return kernel_multiply_test<CPUType::AVX2, float>(); }
+KERNEL_TEST_CASE("multiply/double AVX2") { return kernel_multiply_test<CPUType::AVX2, double>(); }
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int8_t>();
+template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int16_t>();
+template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int>();
+template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, float>();
+template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, double>();
+KERNEL_TEST_CASE("multiply/int8 AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int8_t>(); }
+KERNEL_TEST_CASE("multiply/int16 AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int16_t>(); }
+KERNEL_TEST_CASE("multiply/int AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, int>(); }
+KERNEL_TEST_CASE("multiply/float AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, float>(); }
+KERNEL_TEST_CASE("multiply/double AVX512BW") { return kernel_multiply_test<CPUType::AVX512BW, double>(); }
+#endif
+
+}
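
A note on the CHECK above: the reference value is computed as Type_(input1[i] * input2[i]), i.e. the full product truncated back to the element type, which keeps the comparison consistent with the low-bits semantics of the integer kernels. In scalar form (illustration only, values chosen by hand):

  int8_t a = -7, b = 25;
  int full = int(a) * int(b);       // -175, does not fit in int8_t
  int8_t expected = int8_t(full);   // keeps the low 8 bits: 81, the value the kernel produces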