diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-09-14 13:38:43 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-09-14 13:38:43 +0300 |
commit | 9c36b961a003f2a230fa2934e766bffc98d8afb9 (patch) | |
tree | fa3b84af6c869e293729f2e6ecd83b3f470986e1 | |
parent | 02f671cf537fdbc818cf8111d1d9e557a8650d7a (diff) |
Use template arguments for slli and friends
Remove multiply_sat kernel for now
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | intgemm/intrinsics.h | 48 | ||||
-rw-r--r-- | intgemm/kernels/implementations.inl | 30 | ||||
-rw-r--r-- | test/kernels/multiply_sat_test.cc | 53 |
4 files changed, 26 insertions, 106 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index d0f8f07..d1885f5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,7 +84,6 @@ add_executable(tests test/kernels/downcast_test.cc test/kernels/exp_test.cc test/kernels/floor_test.cc - test/kernels/multiply_sat_test.cc test/kernels/multiply_test.cc test/kernels/quantize_test.cc test/kernels/relu_test.cc diff --git a/intgemm/intrinsics.h b/intgemm/intrinsics.h index 03cedd2..480f421 100644 --- a/intgemm/intrinsics.h +++ b/intgemm/intrinsics.h @@ -161,17 +161,17 @@ template <> INTGEMM_SSE2 inline __m128i setzero_si<__m128i>() { INTGEMM_SSSE3 static inline __m128i sign_epi8(__m128i first, __m128i second) { return _mm_sign_epi8(first, second); } -INTGEMM_SSE2 static inline __m128i slli_epi16(__m128i a, int8_t b) { - return _mm_slli_epi16(a, b); +template <int imm8> INTGEMM_SSE2 static inline __m128i slli_epi16(__m128i a) { + return _mm_slli_epi16(a, imm8); } -INTGEMM_SSE2 static inline __m128i srai_epi16(__m128i a, int8_t b) { - return _mm_srai_epi16(a, b); +template <int imm8> INTGEMM_SSE2 static inline __m128i srai_epi16(__m128i a) { + return _mm_srai_epi16(a, imm8); } -INTGEMM_SSE2 static inline __m128i srai_epi32(__m128i a, int8_t b) { - return _mm_srai_epi32(a, b); +template <int imm8> INTGEMM_SSE2 static inline __m128i srai_epi32(__m128i a) { + return _mm_srai_epi32(a, imm8); } -INTGEMM_SSE2 static inline __m128i srli_epi16(__m128i a, int8_t b) { - return _mm_srli_epi16(a, b); +template <int imm8> INTGEMM_SSE2 static inline __m128i srli_epi16(__m128i a) { + return _mm_srli_epi16(a, imm8); } INTGEMM_SSE2 static inline void storeu_ps(float* mem_addr, __m128 a) { _mm_storeu_ps(mem_addr, a); @@ -342,17 +342,17 @@ template <> INTGEMM_AVX2 inline __m256i setzero_si<__m256i>() { INTGEMM_AVX2 static inline __m256i sign_epi8(__m256i first, __m256i second) { return _mm256_sign_epi8(first, second); } -INTGEMM_AVX2 static inline __m256i slli_epi16(__m256i a, int8_t b) { - return _mm256_slli_epi16(a, b); +template <int imm8> INTGEMM_AVX2 static inline __m256i slli_epi16(__m256i a) { + return _mm256_slli_epi16(a, imm8); } -INTGEMM_AVX2 static inline __m256i srai_epi16(__m256i a, int8_t b) { - return _mm256_srai_epi16(a, b); +template <int imm8> INTGEMM_AVX2 static inline __m256i srai_epi16(__m256i a) { + return _mm256_srai_epi16(a, imm8); } -INTGEMM_AVX2 static inline __m256i srai_epi32(__m256i a, int8_t b) { - return _mm256_srai_epi32(a, b); +template <int imm8> INTGEMM_AVX2 static inline __m256i srai_epi32(__m256i a) { + return _mm256_srai_epi32(a, imm8); } -INTGEMM_AVX2 static inline __m256i srli_epi16(__m256i a, int8_t b) { - return _mm256_srli_epi16(a, b); +template <int imm8> INTGEMM_AVX2 static inline __m256i srli_epi16(__m256i a) { + return _mm256_srli_epi16(a, imm8); } INTGEMM_AVX2 static inline void storeu_ps(float* mem_addr, __m256 a) { _mm256_storeu_ps(mem_addr, a); @@ -539,17 +539,17 @@ template <> INTGEMM_AVX512BW inline __m512 load_ps<__m512>(const float* from) { /* * Missing sign_epi8 */ -INTGEMM_AVX512BW static inline __m512i slli_epi16(__m512i a, int8_t b) { - return _mm512_slli_epi16(a, b); +template <int imm8> INTGEMM_AVX512BW static inline __m512i slli_epi16(__m512i a) { + return _mm512_slli_epi16(a, imm8); } -INTGEMM_AVX512BW static inline __m512i srai_epi16(__m512i a, int8_t b) { - return _mm512_srai_epi16(a, b); +template <int imm8> INTGEMM_AVX512BW static inline __m512i srai_epi16(__m512i a) { + return _mm512_srai_epi16(a, imm8); } -INTGEMM_AVX512BW static inline __m512i srai_epi32(__m512i a, int8_t b) { - return _mm512_srai_epi32(a, b); +template <int imm8> INTGEMM_AVX512BW static inline __m512i srai_epi32(__m512i a) { + return _mm512_srai_epi32(a, imm8); } -INTGEMM_AVX512BW static inline __m512i srli_epi16(__m512i a, int8_t b) { - return _mm512_srli_epi16(a, b); +template <int imm8> INTGEMM_AVX512BW static inline __m512i srli_epi16(__m512i a) { + return _mm512_srli_epi16(a, imm8); } INTGEMM_AVX512BW static inline void storeu_ps(float* mem_addr, __m512 a) { _mm512_storeu_ps(mem_addr, a); diff --git a/intgemm/kernels/implementations.inl b/intgemm/kernels/implementations.inl index 2ec9f1f..6edfafd 100644 --- a/intgemm/kernels/implementations.inl +++ b/intgemm/kernels/implementations.inl @@ -145,8 +145,8 @@ CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> multiply(vector_t<CPUTy template <> CPU_ATTR inline vi multiply<int8_t>(vi a, vi b) { auto even = mullo_epi16(a, b); - auto odd = mullo_epi16(srli_epi16(a, 8), srli_epi16(b, 8)); - return or_si(slli_epi16(odd, 8), srli_epi16(slli_epi16(even, 8), 8)); + auto odd = mullo_epi16(srli_epi16<8>(a), srli_epi16<8>(b)); + return or_si(slli_epi16<8>(odd), srli_epi16<8>(slli_epi16<8>(even))); } template <> @@ -296,32 +296,6 @@ CPU_ATTR static inline vi bitwise_not(vi v) { } /* - * Multiply with saturation (elemwise) - */ -template <typename Type> -CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> multiply_sat(vector_t<CPUType::CPU_NAME, Type> a, vector_t<CPUType::CPU_NAME, Type> b, uint8_t right_shift); - -template <> -CPU_ATTR inline vi multiply_sat<int8_t>(vi a, vi b, uint8_t right_shift) { - auto upcasted_a = upcast8to16(a); - auto upcasted_b = upcast8to16(b); - auto low = srai_epi16(multiply<int16_t>(upcasted_a.first, upcasted_b.first), right_shift); - auto hi = srai_epi16(multiply<int16_t>(upcasted_a.second, upcasted_b.second), right_shift); - - return downcast16to8(low, hi); -} - -template <> -CPU_ATTR inline vi multiply_sat<int16_t>(vi a, vi b, uint8_t right_shift) { - auto upcasted_a = upcast16to32(a); - auto upcasted_b = upcast16to32(b); - auto low = srai_epi32(multiply<int32_t>(upcasted_a.first, upcasted_b.first), right_shift); - auto hi = srai_epi32(multiply<int32_t>(upcasted_a.second, upcasted_b.second), right_shift); - - return downcast32to16(low, hi); -} - -/* * Floor */ CPU_ATTR static inline vf floor(vf input) { diff --git a/test/kernels/multiply_sat_test.cc b/test/kernels/multiply_sat_test.cc deleted file mode 100644 index 6d8ed22..0000000 --- a/test/kernels/multiply_sat_test.cc +++ /dev/null @@ -1,53 +0,0 @@ -#include "../test.h" -#include "../../intgemm/aligned.h" -#include "../../intgemm/kernels.h" - -#include <cstdint> -#include <cstddef> -#include <numeric> - -namespace intgemm { - -template <CPUType CPUType_, typename Type_> -void kernel_multiply_sat_test() { - if (kCPU < CPUType_) - return; - - using vec_t = vector_t<CPUType_, Type_>; - constexpr int VECTOR_LENGTH = sizeof(vec_t) / sizeof(Type_); - - AlignedVector<Type_> input1(VECTOR_LENGTH); - AlignedVector<Type_> input2(VECTOR_LENGTH); - AlignedVector<Type_> output(VECTOR_LENGTH); - - std::iota(input1.begin(), input1.end(), static_cast<Type_>(-VECTOR_LENGTH / 2)); - std::iota(input2.begin(), input2.end(), static_cast<Type_>(-VECTOR_LENGTH / 3)); - - // TODO: try all shifts. The shift must be an immediate. - int8_t shift = 1; - *output.template as<vec_t>() = kernels::multiply_sat<Type_>(*input1.template as<vec_t>(), *input2.template as<vec_t>(), shift); - for (std::size_t i = 0; i < output.size(); ++i) { - auto ref = (int64_t(input1[i]) * input2[i]) >> shift; - auto ref_sat = Type_(std::min<int64_t>(std::numeric_limits<Type_>::max(), std::max<int64_t>(std::numeric_limits<Type_>::min(), ref))); - CHECK(output[i] == ref_sat); - } -} - -template INTGEMM_SSE2 void kernel_multiply_sat_test<CPUType::SSE2, int8_t>(); -template INTGEMM_SSE2 void kernel_multiply_sat_test<CPUType::SSE2, int16_t>(); -KERNEL_TEST_CASE("multiply_sat/int8 SSE2") { return kernel_multiply_sat_test<CPUType::SSE2, int8_t>(); } -KERNEL_TEST_CASE("multiply_sat/int16 SSE2") { return kernel_multiply_sat_test<CPUType::SSE2, int16_t>(); } - -template INTGEMM_AVX2 void kernel_multiply_sat_test<CPUType::AVX2, int8_t>(); -template INTGEMM_AVX2 void kernel_multiply_sat_test<CPUType::AVX2, int16_t>(); -KERNEL_TEST_CASE("multiply_sat/int8 AVX2") { return kernel_multiply_sat_test<CPUType::AVX2, int8_t>(); } -KERNEL_TEST_CASE("multiply_sat/int16 AVX2") { return kernel_multiply_sat_test<CPUType::AVX2, int16_t>(); } - -#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW -template INTGEMM_AVX512BW void kernel_multiply_sat_test<CPUType::AVX512BW, int8_t>(); -template INTGEMM_AVX512BW void kernel_multiply_sat_test<CPUType::AVX512BW, int16_t>(); -KERNEL_TEST_CASE("multiply_sat/int8 AVX512BW") { return kernel_multiply_sat_test<CPUType::AVX512BW, int8_t>(); } -KERNEL_TEST_CASE("multiply_sat/int16 AVX512BW") { return kernel_multiply_sat_test<CPUType::AVX512BW, int16_t>(); } -#endif - -} |