github.com/marian-nmt/intgemm.git

author     Kenneth Heafield <github@kheafield.com>   2020-09-14 13:38:43 +0300
committer  Kenneth Heafield <github@kheafield.com>   2020-09-14 13:38:43 +0300
commit     9c36b961a003f2a230fa2934e766bffc98d8afb9 (patch)
tree       fa3b84af6c869e293729f2e6ecd83b3f470986e1
parent     02f671cf537fdbc818cf8111d1d9e557a8650d7a (diff)

Use template arguments for slli and friends
Remove multiply_sat kernel for now

-rw-r--r--  CMakeLists.txt                         1
-rw-r--r--  intgemm/intrinsics.h                  48
-rw-r--r--  intgemm/kernels/implementations.inl   30
-rw-r--r--  test/kernels/multiply_sat_test.cc     53
4 files changed, 26 insertions(+), 106 deletions(-)
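
The motivation for the template change, sketched below: the x86 shift intrinsics (_mm_slli_epi16 and friends) are documented to take an immediate count. Passing the count as a runtime function parameter relies on inlining and constant propagation to turn it back into an immediate, and depending on the compiler and optimization level that may fail to compile or fall back to a slower code path. Making the count a template argument guarantees it is a constant expression at every call site. The sketch is illustrative only; the names slli_epi16_old and slli_epi16_new and the main() driver are not part of the repository.

// Minimal sketch (assumes an SSE2 target; names are illustrative).
#include <emmintrin.h>
#include <cstdint>
#include <cstdio>

// Old style: the count arrives as a runtime value, so the intrinsic only sees
// an immediate if the compiler inlines and constant-folds the wrapper.
static inline __m128i slli_epi16_old(__m128i a, int8_t b) {
  return _mm_slli_epi16(a, b);
}

// New style: the count is a template argument, hence a constant expression,
// so the immediate form of the instruction can always be emitted.
template <int imm8> static inline __m128i slli_epi16_new(__m128i a) {
  return _mm_slli_epi16(a, imm8);
}

int main() {
  __m128i v = _mm_set1_epi16(3);
  __m128i r = slli_epi16_new<8>(v);   // call sites change from (v, 8) to <8>(v)
  alignas(16) int16_t out[8];
  _mm_store_si128(reinterpret_cast<__m128i*>(out), r);
  std::printf("%d\n", out[0]);        // prints 768, i.e. 3 << 8
  (void)slli_epi16_old(v, 8);         // the old form still compiles here because 8 is a literal
  return 0;
}
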
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d0f8f07..d1885f5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -84,7 +84,6 @@ add_executable(tests
test/kernels/downcast_test.cc
test/kernels/exp_test.cc
test/kernels/floor_test.cc
- test/kernels/multiply_sat_test.cc
test/kernels/multiply_test.cc
test/kernels/quantize_test.cc
test/kernels/relu_test.cc
diff --git a/intgemm/intrinsics.h b/intgemm/intrinsics.h
index 03cedd2..480f421 100644
--- a/intgemm/intrinsics.h
+++ b/intgemm/intrinsics.h
@@ -161,17 +161,17 @@ template <> INTGEMM_SSE2 inline __m128i setzero_si<__m128i>() {
INTGEMM_SSSE3 static inline __m128i sign_epi8(__m128i first, __m128i second) {
return _mm_sign_epi8(first, second);
}
-INTGEMM_SSE2 static inline __m128i slli_epi16(__m128i a, int8_t b) {
- return _mm_slli_epi16(a, b);
+template <int imm8> INTGEMM_SSE2 static inline __m128i slli_epi16(__m128i a) {
+ return _mm_slli_epi16(a, imm8);
}
-INTGEMM_SSE2 static inline __m128i srai_epi16(__m128i a, int8_t b) {
- return _mm_srai_epi16(a, b);
+template <int imm8> INTGEMM_SSE2 static inline __m128i srai_epi16(__m128i a) {
+ return _mm_srai_epi16(a, imm8);
}
-INTGEMM_SSE2 static inline __m128i srai_epi32(__m128i a, int8_t b) {
- return _mm_srai_epi32(a, b);
+template <int imm8> INTGEMM_SSE2 static inline __m128i srai_epi32(__m128i a) {
+ return _mm_srai_epi32(a, imm8);
}
-INTGEMM_SSE2 static inline __m128i srli_epi16(__m128i a, int8_t b) {
- return _mm_srli_epi16(a, b);
+template <int imm8> INTGEMM_SSE2 static inline __m128i srli_epi16(__m128i a) {
+ return _mm_srli_epi16(a, imm8);
}
INTGEMM_SSE2 static inline void storeu_ps(float* mem_addr, __m128 a) {
_mm_storeu_ps(mem_addr, a);
@@ -342,17 +342,17 @@ template <> INTGEMM_AVX2 inline __m256i setzero_si<__m256i>() {
INTGEMM_AVX2 static inline __m256i sign_epi8(__m256i first, __m256i second) {
return _mm256_sign_epi8(first, second);
}
-INTGEMM_AVX2 static inline __m256i slli_epi16(__m256i a, int8_t b) {
- return _mm256_slli_epi16(a, b);
+template <int imm8> INTGEMM_AVX2 static inline __m256i slli_epi16(__m256i a) {
+ return _mm256_slli_epi16(a, imm8);
}
-INTGEMM_AVX2 static inline __m256i srai_epi16(__m256i a, int8_t b) {
- return _mm256_srai_epi16(a, b);
+template <int imm8> INTGEMM_AVX2 static inline __m256i srai_epi16(__m256i a) {
+ return _mm256_srai_epi16(a, imm8);
}
-INTGEMM_AVX2 static inline __m256i srai_epi32(__m256i a, int8_t b) {
- return _mm256_srai_epi32(a, b);
+template <int imm8> INTGEMM_AVX2 static inline __m256i srai_epi32(__m256i a) {
+ return _mm256_srai_epi32(a, imm8);
}
-INTGEMM_AVX2 static inline __m256i srli_epi16(__m256i a, int8_t b) {
- return _mm256_srli_epi16(a, b);
+template <int imm8> INTGEMM_AVX2 static inline __m256i srli_epi16(__m256i a) {
+ return _mm256_srli_epi16(a, imm8);
}
INTGEMM_AVX2 static inline void storeu_ps(float* mem_addr, __m256 a) {
_mm256_storeu_ps(mem_addr, a);
@@ -539,17 +539,17 @@ template <> INTGEMM_AVX512BW inline __m512 load_ps<__m512>(const float* from) {
/*
* Missing sign_epi8
*/
-INTGEMM_AVX512BW static inline __m512i slli_epi16(__m512i a, int8_t b) {
- return _mm512_slli_epi16(a, b);
+template <int imm8> INTGEMM_AVX512BW static inline __m512i slli_epi16(__m512i a) {
+ return _mm512_slli_epi16(a, imm8);
}
-INTGEMM_AVX512BW static inline __m512i srai_epi16(__m512i a, int8_t b) {
- return _mm512_srai_epi16(a, b);
+template <int imm8> INTGEMM_AVX512BW static inline __m512i srai_epi16(__m512i a) {
+ return _mm512_srai_epi16(a, imm8);
}
-INTGEMM_AVX512BW static inline __m512i srai_epi32(__m512i a, int8_t b) {
- return _mm512_srai_epi32(a, b);
+template <int imm8> INTGEMM_AVX512BW static inline __m512i srai_epi32(__m512i a) {
+ return _mm512_srai_epi32(a, imm8);
}
-INTGEMM_AVX512BW static inline __m512i srli_epi16(__m512i a, int8_t b) {
- return _mm512_srli_epi16(a, b);
+template <int imm8> INTGEMM_AVX512BW static inline __m512i srli_epi16(__m512i a) {
+ return _mm512_srli_epi16(a, imm8);
}
INTGEMM_AVX512BW static inline void storeu_ps(float* mem_addr, __m512 a) {
_mm512_storeu_ps(mem_addr, a);
diff --git a/intgemm/kernels/implementations.inl b/intgemm/kernels/implementations.inl
index 2ec9f1f..6edfafd 100644
--- a/intgemm/kernels/implementations.inl
+++ b/intgemm/kernels/implementations.inl
@@ -145,8 +145,8 @@ CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> multiply(vector_t<CPUTy
template <>
CPU_ATTR inline vi multiply<int8_t>(vi a, vi b) {
auto even = mullo_epi16(a, b);
- auto odd = mullo_epi16(srli_epi16(a, 8), srli_epi16(b, 8));
- return or_si(slli_epi16(odd, 8), srli_epi16(slli_epi16(even, 8), 8));
+ auto odd = mullo_epi16(srli_epi16<8>(a), srli_epi16<8>(b));
+ return or_si(slli_epi16<8>(odd), srli_epi16<8>(slli_epi16<8>(even)));
}
template <>
@@ -296,32 +296,6 @@ CPU_ATTR static inline vi bitwise_not(vi v) {
}
/*
- * Multiply with saturation (elemwise)
- */
-template <typename Type>
-CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> multiply_sat(vector_t<CPUType::CPU_NAME, Type> a, vector_t<CPUType::CPU_NAME, Type> b, uint8_t right_shift);
-
-template <>
-CPU_ATTR inline vi multiply_sat<int8_t>(vi a, vi b, uint8_t right_shift) {
- auto upcasted_a = upcast8to16(a);
- auto upcasted_b = upcast8to16(b);
- auto low = srai_epi16(multiply<int16_t>(upcasted_a.first, upcasted_b.first), right_shift);
- auto hi = srai_epi16(multiply<int16_t>(upcasted_a.second, upcasted_b.second), right_shift);
-
- return downcast16to8(low, hi);
-}
-
-template <>
-CPU_ATTR inline vi multiply_sat<int16_t>(vi a, vi b, uint8_t right_shift) {
- auto upcasted_a = upcast16to32(a);
- auto upcasted_b = upcast16to32(b);
- auto low = srai_epi32(multiply<int32_t>(upcasted_a.first, upcasted_b.first), right_shift);
- auto hi = srai_epi32(multiply<int32_t>(upcasted_a.second, upcasted_b.second), right_shift);
-
- return downcast32to16(low, hi);
-}
-
-/*
* Floor
*/
CPU_ATTR static inline vf floor(vf input) {
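
The multiply<int8_t> kernel in the hunk above builds an elementwise 8-bit product out of 16-bit multiplies: one mullo_epi16 handles the even (low) bytes of each 16-bit lane, a second handles the odd (high) bytes after both inputs are shifted right by 8, and the shift/or sequence keeps only the low 8 bits of each product before recombining them. The scalar sketch below is illustrative only (the function name and test values are not from the repository); it mirrors that lane-level logic so the bit manipulation is easier to follow.

#include <cstdint>
#include <cstdio>

// One 16-bit lane holding two int8 values, multiplied the same way the SIMD
// kernel does it: even (low) byte and odd (high) byte handled separately,
// each product truncated to 8 bits, then recombined.
static uint16_t multiply_int8_lane(uint16_t a, uint16_t b) {
  uint16_t even = static_cast<uint16_t>(a * b);                // mullo_epi16(a, b)
  uint16_t odd  = static_cast<uint16_t>((a >> 8) * (b >> 8));  // mullo_epi16(srli_epi16<8>(a), srli_epi16<8>(b))
  uint16_t lo   = static_cast<uint16_t>(even & 0xFF);          // srli_epi16<8>(slli_epi16<8>(even))
  uint16_t hi   = static_cast<uint16_t>((odd & 0xFF) << 8);    // slli_epi16<8>(odd)
  return static_cast<uint16_t>(hi | lo);                       // or_si(hi, lo)
}

int main() {
  int8_t a_lo = -7, a_hi = 5, b_lo = 9, b_hi = -3;
  uint16_t a = static_cast<uint16_t>((static_cast<uint8_t>(a_hi) << 8) | static_cast<uint8_t>(a_lo));
  uint16_t b = static_cast<uint16_t>((static_cast<uint8_t>(b_hi) << 8) | static_cast<uint8_t>(b_lo));
  uint16_t r = multiply_int8_lane(a, b);
  std::printf("low  byte: %d (expect %d)\n", static_cast<int8_t>(r & 0xFF), static_cast<int8_t>(a_lo * b_lo));
  std::printf("high byte: %d (expect %d)\n", static_cast<int8_t>(r >> 8), static_cast<int8_t>(a_hi * b_hi));
  return 0;
}
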
diff --git a/test/kernels/multiply_sat_test.cc b/test/kernels/multiply_sat_test.cc
deleted file mode 100644
index 6d8ed22..0000000
--- a/test/kernels/multiply_sat_test.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-#include "../test.h"
-#include "../../intgemm/aligned.h"
-#include "../../intgemm/kernels.h"
-
-#include <cstdint>
-#include <cstddef>
-#include <numeric>
-
-namespace intgemm {
-
-template <CPUType CPUType_, typename Type_>
-void kernel_multiply_sat_test() {
- if (kCPU < CPUType_)
- return;
-
- using vec_t = vector_t<CPUType_, Type_>;
- constexpr int VECTOR_LENGTH = sizeof(vec_t) / sizeof(Type_);
-
- AlignedVector<Type_> input1(VECTOR_LENGTH);
- AlignedVector<Type_> input2(VECTOR_LENGTH);
- AlignedVector<Type_> output(VECTOR_LENGTH);
-
- std::iota(input1.begin(), input1.end(), static_cast<Type_>(-VECTOR_LENGTH / 2));
- std::iota(input2.begin(), input2.end(), static_cast<Type_>(-VECTOR_LENGTH / 3));
-
- // TODO: try all shifts. The shift must be an immediate.
- int8_t shift = 1;
- *output.template as<vec_t>() = kernels::multiply_sat<Type_>(*input1.template as<vec_t>(), *input2.template as<vec_t>(), shift);
- for (std::size_t i = 0; i < output.size(); ++i) {
- auto ref = (int64_t(input1[i]) * input2[i]) >> shift;
- auto ref_sat = Type_(std::min<int64_t>(std::numeric_limits<Type_>::max(), std::max<int64_t>(std::numeric_limits<Type_>::min(), ref)));
- CHECK(output[i] == ref_sat);
- }
-}
-
-template INTGEMM_SSE2 void kernel_multiply_sat_test<CPUType::SSE2, int8_t>();
-template INTGEMM_SSE2 void kernel_multiply_sat_test<CPUType::SSE2, int16_t>();
-KERNEL_TEST_CASE("multiply_sat/int8 SSE2") { return kernel_multiply_sat_test<CPUType::SSE2, int8_t>(); }
-KERNEL_TEST_CASE("multiply_sat/int16 SSE2") { return kernel_multiply_sat_test<CPUType::SSE2, int16_t>(); }
-
-template INTGEMM_AVX2 void kernel_multiply_sat_test<CPUType::AVX2, int8_t>();
-template INTGEMM_AVX2 void kernel_multiply_sat_test<CPUType::AVX2, int16_t>();
-KERNEL_TEST_CASE("multiply_sat/int8 AVX2") { return kernel_multiply_sat_test<CPUType::AVX2, int8_t>(); }
-KERNEL_TEST_CASE("multiply_sat/int16 AVX2") { return kernel_multiply_sat_test<CPUType::AVX2, int16_t>(); }
-
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
-template INTGEMM_AVX512BW void kernel_multiply_sat_test<CPUType::AVX512BW, int8_t>();
-template INTGEMM_AVX512BW void kernel_multiply_sat_test<CPUType::AVX512BW, int16_t>();
-KERNEL_TEST_CASE("multiply_sat/int8 AVX512BW") { return kernel_multiply_sat_test<CPUType::AVX512BW, int8_t>(); }
-KERNEL_TEST_CASE("multiply_sat/int16 AVX512BW") { return kernel_multiply_sat_test<CPUType::AVX512BW, int16_t>(); }
-#endif
-
-}
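
For reference, the removed test's TODO ("The shift must be an immediate") is exactly what the templated shifts above provide, and the commit message only removes the kernel "for now". A possible shape for a reintroduced multiply_sat with the right shift as a template argument is sketched below. This is not part of the commit: it uses raw SSE2 intrinsics so it is self-contained, whereas the real kernel would go through the repository's wrappers (upcast8to16, multiply<int16_t>, srai_epi16<...>, downcast16to8).

#include <emmintrin.h>

// Saturating elementwise int8 multiply with a compile-time right shift:
// widen to int16, multiply, arithmetic-shift by the immediate, then pack
// back to int8 with signed saturation (the same reference the removed test checked).
template <int right_shift>
static inline __m128i multiply_sat_int8(__m128i a, __m128i b) {
  // Sign-extend the low and high halves of each input from int8 to int16.
  __m128i a_lo = _mm_srai_epi16(_mm_unpacklo_epi8(a, a), 8);
  __m128i a_hi = _mm_srai_epi16(_mm_unpackhi_epi8(a, a), 8);
  __m128i b_lo = _mm_srai_epi16(_mm_unpacklo_epi8(b, b), 8);
  __m128i b_hi = _mm_srai_epi16(_mm_unpackhi_epi8(b, b), 8);
  // int8 * int8 always fits in int16, so mullo_epi16 is exact here.
  __m128i lo = _mm_srai_epi16(_mm_mullo_epi16(a_lo, b_lo), right_shift);
  __m128i hi = _mm_srai_epi16(_mm_mullo_epi16(a_hi, b_hi), right_shift);
  // Narrow back to int8 with signed saturation.
  return _mm_packs_epi16(lo, hi);
}

A call site would read multiply_sat_int8<1>(a, b), which would also let a revived test iterate over shift amounts at compile time instead of passing a runtime shift.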