Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
author Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-22 18:20:25 +0300
committer Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-22 18:28:00 +0300
commit edabfc96e5576479e7f88b4c6bfee75c7dfda9bd (patch)
tree c95d210b1e9b8402c18f9ca4a6381ad99204dd02 /kernels/implementations.inl
parent 721f4802464431dfecbc7c4bed68850f81b7af70 (diff)
Add multiply (elemwise) kernel
Diffstat (limited to 'kernels/implementations.inl')
-rw-r--r-- kernels/implementations.inl 41
1 file changed, 41 insertions, 0 deletions
diff --git a/kernels/implementations.inl b/kernels/implementations.inl
index fd46390..e2565b3 100644
--- a/kernels/implementations.inl
+++ b/kernels/implementations.inl
@@ -142,6 +142,47 @@ CPU_ATTR inline vd relu<double>(vd input) {
}
/*
+ * Multiply (elemwise)
+ */
+template <typename Type> // Type is the element type carried by the vector lanes
+CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> multiply(vector_t<CPUType::CPU_NAME, Type> a, vector_t<CPUType::CPU_NAME, Type> b); // declaration only: element-wise a * b; each supported element type is provided by an explicit specialization
+
+template <>
+CPU_ATTR inline vi multiply<int8_t>(vi a, vi b) { // element-wise 8-bit product, assembled from two 16-bit lane multiplies
+ auto even = mullo_epi16(a, b); // the low byte of each 16-bit product depends only on the low bytes of a and b (assumes mullo_epi16 keeps the low 16 bits per lane -- confirm against the wrapper)
+ auto odd = mullo_epi16(srli_epi16(a, 8), srli_epi16(b, 8)); // bring the high byte of every 16-bit lane down to the low byte, then multiply those
+ return or_si(slli_epi16(odd, 8), srli_epi16(slli_epi16(even, 8), 8)); // odd products go back to the high byte; shifting even left then right by 8 clears its high byte (assumes srli_epi16 is a logical shift like the _mm*_srli_epi16 intrinsics -- confirm)
+}
+
+template <>
+CPU_ATTR inline vi multiply<int16_t>(vi a, vi b) { // element-wise 16-bit product
+ return mullo_epi16(a, b); // forwards directly to the mullo_epi16 wrapper; presumably keeps the low 16 bits of each product -- confirm against the wrapper
+}
+
+template <>
+CPU_ATTR inline vi multiply<int>(vi a, vi b) { // element-wise 32-bit product; the implementation is selected by the THIS_IS_* macros
+#if defined(THIS_IS_SSE2)
+ auto even = mul_epu32(a, b); // presumably wraps _mm_mul_epu32: 64-bit products of 32-bit elements 0 and 2 -- confirm
+ auto odd = mul_epu32(_mm_srli_si128(a, 4), _mm_srli_si128(b, 4)); // byte-shift both inputs right by 4 so elements 1 and 3 sit in positions 0 and 2, then multiply those
+ return unpacklo_epi32(shuffle_epi32(even, 0x8 /* = 0 0 2 0 */), shuffle_epi32(odd, 0x8 /* = 0 0 2 0 */)); // pack the low 32 bits of each product into the two low elements, then interleave even/odd back into element order
+#elif defined(THIS_IS_AVX2)
+ return _mm256_mullo_epi32(a, b); // 256-bit intrinsic: low 32 bits of each 32-bit product
+#else
+ return _mm512_mullo_epi32(a, b); // 512-bit intrinsic; NOTE(review): every target that is neither SSE2 nor AVX2 lands here -- confirm only AVX-512 builds include this file
+#endif
+}
+
+template <>
+CPU_ATTR inline vf multiply<float>(vf a, vf b) { // element-wise single-precision product
+ return mul_ps(a, b); // forwards to the mul_ps wrapper; presumably the packed-float multiply for the current CPU target -- confirm
+}
+
+template <>
+CPU_ATTR inline vd multiply<double>(vd a, vd b) { // element-wise double-precision product
+ return mul_pd(a, b); // forwards to the mul_pd wrapper; presumably the packed-double multiply for the current CPU target -- confirm
+}
+
+/*
* Floor
*/
CPU_ATTR static inline vf floor(vf input) {