diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-19 19:15:58 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-19 19:16:37 +0300 |
commit | a80efb933528ffbad8d17f7f7a915aeee1c3e0e7 (patch) | |
tree | cb1e654e72398622014528aba9acb0ff178dcf1f | |
parent | d02f3ce07214122f1f36f2fdd0379db8a9abd409 (diff) |
Add vec_traits for int8 and int16
-rw-r--r-- | CMakeLists.txt | 3 | ||||
-rw-r--r-- | intrinsics.h | 47 | ||||
-rw-r--r-- | kernels/implementations.inl | 83 | ||||
-rw-r--r-- | test/kernels/add_bias_test.cc | 26 | ||||
-rw-r--r-- | test/kernels/floor_test.cc (renamed from test/kernels/floor_ff_test.cc) | 16 | ||||
-rw-r--r-- | test/kernels/highway_test.cc | 48 | ||||
-rw-r--r-- | test/kernels/relu_test.cc | 14 | ||||
-rw-r--r-- | test/kernels/write_test.cc | 14 | ||||
-rw-r--r-- | vec_traits.h | 6 |
9 files changed, 174 insertions, 83 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index ff4048e..2a15175 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,8 +48,7 @@ add_executable(tests # Kernels tests test/kernels/add_bias_test.cc test/kernels/exp_test.cc - test/kernels/floor_ff_test.cc - test/kernels/highway_test.cc + test/kernels/floor_test.cc test/kernels/quantize_test.cc test/kernels/relu_test.cc test/kernels/sigmoid_test.cc diff --git a/intrinsics.h b/intrinsics.h index 92b06c9..f6c63a9 100644 --- a/intrinsics.h +++ b/intrinsics.h @@ -17,10 +17,12 @@ namespace intgemm { * templates. */ template <class Register> static inline Register loadu_ps(const float* mem_addr); +template <class Register> static inline Register set1_epi8(int8_t to); template <class Register> static inline Register set1_epi16(int16_t to); template <class Register> static inline Register set1_epi32(int32_t to); template <class Register> static inline Register set1_pd(double to); template <class Register> static inline Register set1_ps(float to); +template <class Register> static inline Register setzero_pd(); template <class Register> static inline Register setzero_ps(); template <class Register> static inline Register setzero_si(); @@ -32,6 +34,12 @@ template <class Register> static inline Register setzero_si(); INTGEMM_SSSE3 static inline __m128i abs_epi8(__m128i arg) { return _mm_abs_epi8(arg); } +INTGEMM_SSE2 static inline __m128i add_epi8(__m128i a, __m128i b) { + return _mm_add_epi8(a, b); +} +INTGEMM_SSE2 static inline __m128i add_epi16(__m128i a, __m128i b) { + return _mm_add_epi16(a, b); +} INTGEMM_SSE2 static inline __m128i add_epi32(__m128i first, __m128i second) { return _mm_add_epi32(first, second); } @@ -47,6 +55,9 @@ INTGEMM_SSE2 static inline __m128 add_ps(__m128 a, __m128 b) { INTGEMM_SSE2 static inline __m128 and_ps(__m128 first, __m128 second) { return _mm_and_ps(first, second); } +INTGEMM_SSE2 static inline __m128i and_si(__m128i a, __m128i b) { + return _mm_and_si128(a, b); +} INTGEMM_SSE2 static inline __m128 cvtepi32_ps(__m128i arg) { return _mm_cvtepi32_ps(arg); } @@ -92,6 +103,9 @@ INTGEMM_SSE2 static inline __m128d mul_pd(__m128d a, __m128d b) { INTGEMM_SSE2 static inline __m128 mul_ps(__m128 a, __m128 b) { return _mm_mul_ps(a, b); } +template <> INTGEMM_SSE2 inline __m128i set1_epi8<__m128i>(int8_t to) { + return _mm_set1_epi8(to); +} template <> INTGEMM_SSE2 inline __m128i set1_epi16<__m128i>(int16_t to) { return _mm_set1_epi16(to); } @@ -104,6 +118,9 @@ template <> INTGEMM_SSE2 inline __m128d set1_pd<__m128d>(double to) { template <> INTGEMM_SSE2 inline __m128 set1_ps<__m128>(float to) { return _mm_set1_ps(to); } +template <> INTGEMM_SSE2 inline __m128d setzero_pd<__m128d>() { + return _mm_setzero_pd(); +} template <> INTGEMM_SSE2 inline __m128 setzero_ps<__m128>() { return _mm_setzero_ps(); } @@ -131,6 +148,12 @@ INTGEMM_SSE2 static inline __m128 sub_ps(__m128 a, __m128 b) { INTGEMM_AVX2 static inline __m256i abs_epi8(__m256i arg) { return _mm256_abs_epi8(arg); } +INTGEMM_AVX2 static inline __m256i add_epi8(__m256i a, __m256i b) { + return _mm256_add_epi8(a, b); +} +INTGEMM_AVX2 static inline __m256i add_epi16(__m256i a, __m256i b) { + return _mm256_add_epi16(a, b); +} INTGEMM_AVX2 static inline __m256i add_epi32(__m256i first, __m256i second) { return _mm256_add_epi32(first, second); } @@ -146,6 +169,9 @@ INTGEMM_AVX2 static inline __m256 add_ps(__m256 a, __m256 b) { INTGEMM_AVX2 static inline __m256 and_ps(__m256 first, __m256 second) { return _mm256_and_ps(first, second); } +INTGEMM_AVX2 static inline __m256i and_si(__m256i a, __m256i b) { + return _mm256_and_si256(a, b); +} INTGEMM_AVX2 static inline __m256 cvtepi32_ps(__m256i arg) { return _mm256_cvtepi32_ps(arg); } @@ -192,6 +218,9 @@ INTGEMM_AVX2 static inline __m256d mul_pd(__m256d a, __m256d b) { INTGEMM_AVX2 static inline __m256 mul_ps(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); } +template <> INTGEMM_AVX2 inline __m256i set1_epi8<__m256i>(int8_t to) { + return _mm256_set1_epi8(to); +} template <> INTGEMM_AVX2 inline __m256i set1_epi16<__m256i>(int16_t to) { return _mm256_set1_epi16(to); } @@ -204,6 +233,9 @@ template <> INTGEMM_AVX2 inline __m256d set1_pd<__m256d>(double to) { template <> INTGEMM_AVX2 inline __m256 set1_ps<__m256>(float to) { return _mm256_set1_ps(to); } +template <> INTGEMM_AVX2 inline __m256d setzero_pd<__m256d>() { + return _mm256_setzero_pd(); +} template <> INTGEMM_AVX2 inline __m256 setzero_ps<__m256>() { return _mm256_setzero_ps(); } @@ -233,6 +265,12 @@ INTGEMM_AVX2 static inline __m256 sub_ps(__m256 a, __m256 b) { INTGEMM_AVX512BW static inline __m512i abs_epi8(__m512i arg) { return _mm512_abs_epi8(arg); } +INTGEMM_AVX512BW static inline __m512i add_epi8(__m512i a, __m512i b) { + return _mm512_add_epi8(a, b); +} +INTGEMM_AVX512BW static inline __m512i add_epi16(__m512i a, __m512i b) { + return _mm512_add_epi16(a, b); +} INTGEMM_AVX512BW static inline __m512i add_epi32(__m512i first, __m512i second) { return _mm512_add_epi32(first, second); } @@ -248,6 +286,9 @@ INTGEMM_AVX512BW static inline __m512 add_ps(__m512 a, __m512 b) { INTGEMM_AVX512DQ static inline __m512 and_ps(__m512 first, __m512 second) { return _mm512_and_ps(first, second); } +INTGEMM_AVX512BW static inline __m512i and_si(__m512i a, __m512i b) { + return _mm512_and_si512(a, b); +} INTGEMM_AVX512BW static inline __m512 cvtepi32_ps(__m512i arg) { return _mm512_cvtepi32_ps(arg); } @@ -294,6 +335,9 @@ INTGEMM_AVX512BW static inline __m512d mul_pd(__m512d a, __m512d b) { INTGEMM_AVX512BW static inline __m512 mul_ps(__m512 a, __m512 b) { return _mm512_mul_ps(a, b); } +template <> inline INTGEMM_AVX512BW __m512i set1_epi8<__m512i>(int8_t to) { + return _mm512_set1_epi8(to); +} template <> inline INTGEMM_AVX512BW __m512i set1_epi16<__m512i>(int16_t to) { return _mm512_set1_epi16(to); } @@ -306,6 +350,9 @@ template <> inline INTGEMM_AVX512BW __m512d set1_pd<__m512d>(double to) { template <> inline INTGEMM_AVX512BW __m512 set1_ps<__m512>(float to) { return _mm512_set1_ps(to); } +template <> INTGEMM_AVX512BW inline __m512d setzero_pd<__m512d>() { + return _mm512_setzero_pd(); +} template <> INTGEMM_AVX512BW inline __m512 setzero_ps<__m512>() { return _mm512_setzero_ps(); } diff --git a/kernels/implementations.inl b/kernels/implementations.inl index 5de6a1e..fd46390 100644 --- a/kernels/implementations.inl +++ b/kernels/implementations.inl @@ -31,6 +31,14 @@ namespace kernels { /* * Write */ +CPU_ATTR static inline void write(vi input, int8_t* output, Index offset) { + *reinterpret_cast<vi*>(output + offset) = input; +} + +CPU_ATTR static inline void write(vi input, int16_t* output, Index offset) { + *reinterpret_cast<vi*>(output + offset) = input; +} + CPU_ATTR static inline void write(vi input, int* output, Index offset) { *reinterpret_cast<vi*>(output + offset) = input; } @@ -60,18 +68,60 @@ CPU_ATTR static inline vf unquantize(vi input, vf unquant_mult) { /* * Add a bias term */ +CPU_ATTR static inline vi add_bias(vi input, const int8_t* bias_addr, Index bias_offset) { + auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset); + return add_epi8(input, bias_term); +} + +CPU_ATTR static inline vi add_bias(vi input, const int16_t* bias_addr, Index bias_offset) { + auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset); + return add_epi16(input, bias_term); +} + +CPU_ATTR static inline vi add_bias(vi input, const int* bias_addr, Index bias_offset) { + auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset); + return add_epi32(input, bias_term); +} + CPU_ATTR static inline vf add_bias(vf input, const float* bias_addr, Index bias_offset) { auto bias_term = *reinterpret_cast<const vf*>(bias_addr + bias_offset); return add_ps(input, bias_term); } +CPU_ATTR static inline vd add_bias(vd input, const double* bias_addr, Index bias_offset) { + auto bias_term = *reinterpret_cast<const vd*>(bias_addr + bias_offset); + return add_pd(input, bias_term); +} + /* * ReLU */ -CPU_ATTR static inline vi relu(vi input) { +template <typename Type> +CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> relu(vector_t<CPUType::CPU_NAME, Type> input); + +template <> +CPU_ATTR inline vi relu<int8_t>(vi input) { + static const auto vconst_zero = set1_epi8<vi>(0); +#if defined(THIS_IS_SSE2) + return and_si(input, _mm_cmplt_epi8(vconst_zero, input)); +#elif defined(THIS_IS_AVX2) + return _mm256_max_epi8(input, vconst_zero); +#else + return _mm512_max_epi8(input, vconst_zero); +#endif +} + +template <> +CPU_ATTR inline vi relu<int16_t>(vi input) { + static const auto vconst_zero = set1_epi16<vi>(0); + return max_epi16(input, vconst_zero); +} + +template <> +CPU_ATTR inline vi relu<int>(vi input) { static const auto vconst_zero = set1_epi32<vi>(0); #if defined(THIS_IS_SSE2) - return _mm_and_si128(input, _mm_cmplt_epi32(vconst_zero, input)); + return and_si(input, _mm_cmplt_epi32(vconst_zero, input)); #elif defined(THIS_IS_AVX2) return _mm256_max_epi32(input, vconst_zero); #else @@ -79,33 +129,22 @@ CPU_ATTR static inline vi relu(vi input) { #endif } -CPU_ATTR static inline vf relu(vf input) { - static const auto vconst_zero = set1_ps<vf>(0); +template <> +CPU_ATTR inline vf relu<float>(vf input) { + static const auto vconst_zero = setzero_ps<vf>(); return max_ps(input, vconst_zero); } -CPU_ATTR static inline vd relu(vd input) { - static const auto vconst_zero = set1_pd<vd>(0); +template <> +CPU_ATTR inline vd relu<double>(vd input) { + static const auto vconst_zero = setzero_pd<vd>(); return max_pd(input, vconst_zero); } /* - * Highway: weight * input1 + ([1] - weight) * input2, [0] <= weight <= [1] - */ -CPU_ATTR static inline vf highway(vf input1, vf input2, vf weight) { - static const auto vconst_one = set1_ps<vf>(1.f); - return add_ps(mul_ps(input1, weight), mul_ps(input2, sub_ps(vconst_one, weight))); -} - -CPU_ATTR static inline vd highway(vd input1, vd input2, vd weight) { - static const auto vconst_one = set1_pd<vd>(1.f); - return add_pd(mul_pd(input1, weight), mul_pd(input2, sub_pd(vconst_one, weight))); -} - -/* - * Calculate floor: float -> float + * Floor */ -CPU_ATTR static inline vf floor_ff(vf input) { +CPU_ATTR static inline vf floor(vf input) { #if defined(THIS_IS_SSE2) static const auto vconst_zero = setzero_ps<vf>(); static const auto vconst_one = set1_ps<vf>(1.f); @@ -167,7 +206,7 @@ CPU_ATTR static inline vf exp_approx_taylor(vf x) { x = max_ps(x, const_min_x); x = min_ps(x, const_max_x); - auto a = floor_ff(x); + auto a = floor(x); auto xa = sub_ps(x, a); auto result = mul_ps(dividers[0], xa); diff --git a/test/kernels/add_bias_test.cc b/test/kernels/add_bias_test.cc index a873b4f..3c4a593 100644 --- a/test/kernels/add_bias_test.cc +++ b/test/kernels/add_bias_test.cc @@ -23,18 +23,42 @@ void kernel_add_bias_test() { *output.template as<vec_t>() = kernels::add_bias(*input.template as<vec_t>(), bias.begin(), 0); for (auto i = 0; i < output.size(); ++i) - CHECK(output[i] == 100 + i); + CHECK(output[i] == ElemType_(100 + i)); } +template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, int8_t>(); +template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, int16_t>(); +template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, int>(); template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, float>(); +template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, double>(); +KERNEL_TEST_CASE("add_bias/int8 SSE2") { return kernel_add_bias_test<CPUType::SSE2, int8_t>(); } +KERNEL_TEST_CASE("add_bias/int16 SSE2") { return kernel_add_bias_test<CPUType::SSE2, int16_t>(); } +KERNEL_TEST_CASE("add_bias/int SSE2") { return kernel_add_bias_test<CPUType::SSE2, int>(); } KERNEL_TEST_CASE("add_bias/float SSE2") { return kernel_add_bias_test<CPUType::SSE2, float>(); } +KERNEL_TEST_CASE("add_bias/double SSE2") { return kernel_add_bias_test<CPUType::SSE2, double>(); } +template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int8_t>(); +template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int16_t>(); +template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int>(); template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, float>(); +template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, double>(); +KERNEL_TEST_CASE("add_bias/int8 AVX2") { return kernel_add_bias_test<CPUType::AVX2, int8_t>(); } +KERNEL_TEST_CASE("add_bias/int16 AVX2") { return kernel_add_bias_test<CPUType::AVX2, int16_t>(); } +KERNEL_TEST_CASE("add_bias/int AVX2") { return kernel_add_bias_test<CPUType::AVX2, int>(); } KERNEL_TEST_CASE("add_bias/float AVX2") { return kernel_add_bias_test<CPUType::AVX2, float>(); } +KERNEL_TEST_CASE("add_bias/double AVX2") { return kernel_add_bias_test<CPUType::AVX2, double>(); } #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512 +template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, int8_t>(); +template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, int16_t>(); +template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, int>(); template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, float>(); +template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, double>(); +KERNEL_TEST_CASE("add_bias/int8 AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, int8_t>(); } +KERNEL_TEST_CASE("add_bias/int16 AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, int16_t>(); } +KERNEL_TEST_CASE("add_bias/int AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, int>(); } KERNEL_TEST_CASE("add_bias/float AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, float>(); } +KERNEL_TEST_CASE("add_bias/double AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, double>(); } #endif } diff --git a/test/kernels/floor_ff_test.cc b/test/kernels/floor_test.cc index 0f36229..8f21af3 100644 --- a/test/kernels/floor_ff_test.cc +++ b/test/kernels/floor_test.cc @@ -7,7 +7,7 @@ namespace intgemm { template <CPUType CPUType_> -void kernel_floor_ff_test() { +void kernel_floor_test() { if (kCPU < CPUType_) return; @@ -19,20 +19,20 @@ void kernel_floor_ff_test() { std::iota(input.begin(), input.end(), -int(VECTOR_LENGTH / 2)); - *output.template as<vec_t>() = kernels::floor_ff(*input.template as<vec_t>()); + *output.template as<vec_t>() = kernels::floor(*input.template as<vec_t>()); for (auto i = 0; i < output.size(); ++i) CHECK(output[i] == std::floor(input[i])); } -template INTGEMM_SSE2 void kernel_floor_ff_test<CPUType::SSE2>(); -KERNEL_TEST_CASE("floor_ff SSE2") { return kernel_floor_ff_test<CPUType::SSE2>(); } +template INTGEMM_SSE2 void kernel_floor_test<CPUType::SSE2>(); +KERNEL_TEST_CASE("floor SSE2") { return kernel_floor_test<CPUType::SSE2>(); } -template INTGEMM_AVX2 void kernel_floor_ff_test<CPUType::AVX2>(); -KERNEL_TEST_CASE("floor_ff AVX2") { return kernel_floor_ff_test<CPUType::AVX2>(); } +template INTGEMM_AVX2 void kernel_floor_test<CPUType::AVX2>(); +KERNEL_TEST_CASE("floor AVX2") { return kernel_floor_test<CPUType::AVX2>(); } #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512 -template INTGEMM_AVX512BW void kernel_floor_ff_test<CPUType::AVX512BW>(); -KERNEL_TEST_CASE("floor_ff AVX512BW") { return kernel_floor_ff_test<CPUType::AVX512BW>(); } +template INTGEMM_AVX512BW void kernel_floor_test<CPUType::AVX512BW>(); +KERNEL_TEST_CASE("floor AVX512BW") { return kernel_floor_test<CPUType::AVX512BW>(); } #endif } diff --git a/test/kernels/highway_test.cc b/test/kernels/highway_test.cc deleted file mode 100644 index 31b9737..0000000 --- a/test/kernels/highway_test.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "test/test.h" -#include "aligned.h" -#include "kernels.h" - -#include <numeric> - -namespace intgemm { - -template <CPUType CPUType_, typename ElemType_> -void kernel_highway_test() { - if (kCPU < CPUType_) - return; - - using vec_t = vector_t<CPUType_, ElemType_>; - constexpr static auto VECTOR_LENGTH = sizeof(vec_t) / sizeof(ElemType_); - - AlignedVector<ElemType_> input1(VECTOR_LENGTH); - AlignedVector<ElemType_> input2(VECTOR_LENGTH); - AlignedVector<ElemType_> weight(VECTOR_LENGTH); - AlignedVector<ElemType_> output(VECTOR_LENGTH); - - std::iota(input1.begin(), input1.end(), 0); - std::iota(input2.begin(), input2.end(), 100); - std::fill(weight.begin(), weight.end(), 0.1); - - *output.template as<vec_t>() = kernels::highway(*input1.template as<vec_t>(), *input2.template as<vec_t>(), *weight.template as<vec_t>()); - for (auto i = 0; i < output.size(); ++i) - CHECK_EPS(output[i], input1[i] * weight[0] + input2[i] * (1 - weight[0]), 0.00001); -} - -template INTGEMM_SSE2 void kernel_highway_test<CPUType::SSE2, float>(); -template INTGEMM_SSE2 void kernel_highway_test<CPUType::SSE2, double>(); -KERNEL_TEST_CASE("highway/float SSE2") { return kernel_highway_test<CPUType::SSE2, float>(); } -KERNEL_TEST_CASE("highway/double SSE2") { return kernel_highway_test<CPUType::SSE2, double>(); } - -template INTGEMM_AVX2 void kernel_highway_test<CPUType::AVX2, float>(); -template INTGEMM_AVX2 void kernel_highway_test<CPUType::AVX2, double>(); -KERNEL_TEST_CASE("highway/float AVX2") { return kernel_highway_test<CPUType::AVX2, float>(); } -KERNEL_TEST_CASE("highway/double AVX2") { return kernel_highway_test<CPUType::AVX2, double>(); } - -#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512 -template INTGEMM_AVX512BW void kernel_highway_test<CPUType::AVX512BW, float>(); -template INTGEMM_AVX512BW void kernel_highway_test<CPUType::AVX512BW, double>(); -KERNEL_TEST_CASE("highway/float AVX512BW") { return kernel_highway_test<CPUType::AVX512BW, float>(); } -KERNEL_TEST_CASE("highway/double AVX512BW") { return kernel_highway_test<CPUType::AVX512BW, double>(); } -#endif - -} diff --git a/test/kernels/relu_test.cc b/test/kernels/relu_test.cc index f9fe684..7631623 100644 --- a/test/kernels/relu_test.cc +++ b/test/kernels/relu_test.cc @@ -19,29 +19,41 @@ void kernel_relu_test() { std::iota(input.begin(), input.end(), -int(VECTOR_LENGTH / 2)); - *output.template as<vec_t>() = kernels::relu(*input.template as<vec_t>()); + *output.template as<vec_t>() = kernels::relu<ElemType_>(*input.template as<vec_t>()); for (auto i = 0; i < output.size(); ++i) CHECK(output[i] == (input[i] < 0 ? 0 : input[i])); } +template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int8_t>(); +template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int16_t>(); template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int>(); template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, float>(); template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, double>(); +KERNEL_TEST_CASE("relu/int8 SSE2") { return kernel_relu_test<CPUType::SSE2, int8_t>(); } +KERNEL_TEST_CASE("relu/int16 SSE2") { return kernel_relu_test<CPUType::SSE2, int16_t>(); } KERNEL_TEST_CASE("relu/int SSE2") { return kernel_relu_test<CPUType::SSE2, int>(); } KERNEL_TEST_CASE("relu/float SSE2") { return kernel_relu_test<CPUType::SSE2, float>(); } KERNEL_TEST_CASE("relu/double SSE2") { return kernel_relu_test<CPUType::SSE2, double>(); } +template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int8_t>(); +template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int16_t>(); template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int>(); template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, float>(); template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, double>(); +KERNEL_TEST_CASE("relu/int8 AVX2") { return kernel_relu_test<CPUType::AVX2, int8_t>(); } +KERNEL_TEST_CASE("relu/int16 AVX2") { return kernel_relu_test<CPUType::AVX2, int16_t>(); } KERNEL_TEST_CASE("relu/int AVX2") { return kernel_relu_test<CPUType::AVX2, int>(); } KERNEL_TEST_CASE("relu/float AVX2") { return kernel_relu_test<CPUType::AVX2, float>(); } KERNEL_TEST_CASE("relu/double AVX2") { return kernel_relu_test<CPUType::AVX2, double>(); } #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512 +template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int8_t>(); +template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int16_t>(); template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int>(); template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, float>(); template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, double>(); +KERNEL_TEST_CASE("relu/int8 AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int8_t>(); } +KERNEL_TEST_CASE("relu/int16 AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int16_t>(); } KERNEL_TEST_CASE("relu/int AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int>(); } KERNEL_TEST_CASE("relu/float AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, float>(); } KERNEL_TEST_CASE("relu/double AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, double>(); } diff --git a/test/kernels/write_test.cc b/test/kernels/write_test.cc index a1834df..8d85600 100644 --- a/test/kernels/write_test.cc +++ b/test/kernels/write_test.cc @@ -21,27 +21,39 @@ void kernel_write_test() { kernels::write(*input.template as<vec_t>(), output.begin(), 0); for (auto i = 0; i < VECTOR_LENGTH; ++i) - CHECK(output[i] == i); + CHECK(output[i] == ElemType_(i)); } +template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, int8_t>(); +template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, int16_t>(); template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, int>(); template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, float>(); template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, double>(); +KERNEL_TEST_CASE("write/int8 SSE2") { return kernel_write_test<CPUType::SSE2, int8_t>(); } +KERNEL_TEST_CASE("write/int16 SSE2") { return kernel_write_test<CPUType::SSE2, int16_t>(); } KERNEL_TEST_CASE("write/int SSE2") { return kernel_write_test<CPUType::SSE2, int>(); } KERNEL_TEST_CASE("write/float SSE2") { return kernel_write_test<CPUType::SSE2, float>(); } KERNEL_TEST_CASE("write/double SSE2") { return kernel_write_test<CPUType::SSE2, double>(); } +template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int8_t>(); +template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int16_t>(); template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int>(); template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, float>(); template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, double>(); +KERNEL_TEST_CASE("write/int8 AVX2") { return kernel_write_test<CPUType::AVX2, int8_t>(); } +KERNEL_TEST_CASE("write/int16 AVX2") { return kernel_write_test<CPUType::AVX2, int16_t>(); } KERNEL_TEST_CASE("write/int AVX2") { return kernel_write_test<CPUType::AVX2, int>(); } KERNEL_TEST_CASE("write/float AVX2") { return kernel_write_test<CPUType::AVX2, float>(); } KERNEL_TEST_CASE("write/double AVX2") { return kernel_write_test<CPUType::AVX2, double>(); } #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512 +template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, int8_t>(); +template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, int16_t>(); template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, int>(); template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, float>(); template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, double>(); +KERNEL_TEST_CASE("write/int8 AVX512BW") { return kernel_write_test<CPUType::AVX512BW, int8_t>(); } +KERNEL_TEST_CASE("write/int16 AVX512BW") { return kernel_write_test<CPUType::AVX512BW, int16_t>(); } KERNEL_TEST_CASE("write/int AVX512BW") { return kernel_write_test<CPUType::AVX512BW, int>(); } KERNEL_TEST_CASE("write/float AVX512BW") { return kernel_write_test<CPUType::AVX512BW, float>(); } KERNEL_TEST_CASE("write/double AVX512BW") { return kernel_write_test<CPUType::AVX512BW, double>(); } diff --git a/vec_traits.h b/vec_traits.h index 4bf369d..33514f9 100644 --- a/vec_traits.h +++ b/vec_traits.h @@ -8,12 +8,18 @@ namespace intgemm { * Vector traits */ template <CPUType CPUType_, typename ElemType_> struct vector_s; +template <> struct vector_s<CPUType::SSE2, int8_t> { using type = __m128i; }; +template <> struct vector_s<CPUType::SSE2, int16_t> { using type = __m128i; }; template <> struct vector_s<CPUType::SSE2, int> { using type = __m128i; }; template <> struct vector_s<CPUType::SSE2, float> { using type = __m128; }; template <> struct vector_s<CPUType::SSE2, double> { using type = __m128d; }; +template <> struct vector_s<CPUType::AVX2, int8_t> { using type = __m256i; }; +template <> struct vector_s<CPUType::AVX2, int16_t> { using type = __m256i; }; template <> struct vector_s<CPUType::AVX2, int> { using type = __m256i; }; template <> struct vector_s<CPUType::AVX2, float> { using type = __m256; }; template <> struct vector_s<CPUType::AVX2, double> { using type = __m256d; }; +template <> struct vector_s<CPUType::AVX512BW, int8_t> { using type = __m512i; }; +template <> struct vector_s<CPUType::AVX512BW, int16_t> { using type = __m512i; }; template <> struct vector_s<CPUType::AVX512BW, int> { using type = __m512i; }; template <> struct vector_s<CPUType::AVX512BW, float> { using type = __m512; }; template <> struct vector_s<CPUType::AVX512BW, double> { using type = __m512d; }; |