Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-19 19:15:58 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-19 19:16:37 +0300
commita80efb933528ffbad8d17f7f7a915aeee1c3e0e7 (patch)
treecb1e654e72398622014528aba9acb0ff178dcf1f
parentd02f3ce07214122f1f36f2fdd0379db8a9abd409 (diff)
Add vec_traits for int8 and int16
-rw-r--r--CMakeLists.txt3
-rw-r--r--intrinsics.h47
-rw-r--r--kernels/implementations.inl83
-rw-r--r--test/kernels/add_bias_test.cc26
-rw-r--r--test/kernels/floor_test.cc (renamed from test/kernels/floor_ff_test.cc)16
-rw-r--r--test/kernels/highway_test.cc48
-rw-r--r--test/kernels/relu_test.cc14
-rw-r--r--test/kernels/write_test.cc14
-rw-r--r--vec_traits.h6
9 files changed, 174 insertions, 83 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ff4048e..2a15175 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -48,8 +48,7 @@ add_executable(tests
# Kernels tests
test/kernels/add_bias_test.cc
test/kernels/exp_test.cc
- test/kernels/floor_ff_test.cc
- test/kernels/highway_test.cc
+ test/kernels/floor_test.cc
test/kernels/quantize_test.cc
test/kernels/relu_test.cc
test/kernels/sigmoid_test.cc
diff --git a/intrinsics.h b/intrinsics.h
index 92b06c9..f6c63a9 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -17,10 +17,12 @@ namespace intgemm {
* templates.
*/
template <class Register> static inline Register loadu_ps(const float* mem_addr);
+template <class Register> static inline Register set1_epi8(int8_t to);
template <class Register> static inline Register set1_epi16(int16_t to);
template <class Register> static inline Register set1_epi32(int32_t to);
template <class Register> static inline Register set1_pd(double to);
template <class Register> static inline Register set1_ps(float to);
+template <class Register> static inline Register setzero_pd();
template <class Register> static inline Register setzero_ps();
template <class Register> static inline Register setzero_si();
@@ -32,6 +34,12 @@ template <class Register> static inline Register setzero_si();
INTGEMM_SSSE3 static inline __m128i abs_epi8(__m128i arg) {
return _mm_abs_epi8(arg);
}
+INTGEMM_SSE2 static inline __m128i add_epi8(__m128i a, __m128i b) {
+ return _mm_add_epi8(a, b);
+}
+INTGEMM_SSE2 static inline __m128i add_epi16(__m128i a, __m128i b) {
+ return _mm_add_epi16(a, b);
+}
INTGEMM_SSE2 static inline __m128i add_epi32(__m128i first, __m128i second) {
return _mm_add_epi32(first, second);
}
@@ -47,6 +55,9 @@ INTGEMM_SSE2 static inline __m128 add_ps(__m128 a, __m128 b) {
INTGEMM_SSE2 static inline __m128 and_ps(__m128 first, __m128 second) {
return _mm_and_ps(first, second);
}
+INTGEMM_SSE2 static inline __m128i and_si(__m128i a, __m128i b) {
+ return _mm_and_si128(a, b);
+}
INTGEMM_SSE2 static inline __m128 cvtepi32_ps(__m128i arg) {
return _mm_cvtepi32_ps(arg);
}
@@ -92,6 +103,9 @@ INTGEMM_SSE2 static inline __m128d mul_pd(__m128d a, __m128d b) {
INTGEMM_SSE2 static inline __m128 mul_ps(__m128 a, __m128 b) {
return _mm_mul_ps(a, b);
}
+template <> INTGEMM_SSE2 inline __m128i set1_epi8<__m128i>(int8_t to) {
+ return _mm_set1_epi8(to);
+}
template <> INTGEMM_SSE2 inline __m128i set1_epi16<__m128i>(int16_t to) {
return _mm_set1_epi16(to);
}
@@ -104,6 +118,9 @@ template <> INTGEMM_SSE2 inline __m128d set1_pd<__m128d>(double to) {
template <> INTGEMM_SSE2 inline __m128 set1_ps<__m128>(float to) {
return _mm_set1_ps(to);
}
+template <> INTGEMM_SSE2 inline __m128d setzero_pd<__m128d>() {
+ return _mm_setzero_pd();
+}
template <> INTGEMM_SSE2 inline __m128 setzero_ps<__m128>() {
return _mm_setzero_ps();
}
@@ -131,6 +148,12 @@ INTGEMM_SSE2 static inline __m128 sub_ps(__m128 a, __m128 b) {
INTGEMM_AVX2 static inline __m256i abs_epi8(__m256i arg) {
return _mm256_abs_epi8(arg);
}
+INTGEMM_AVX2 static inline __m256i add_epi8(__m256i a, __m256i b) {
+ return _mm256_add_epi8(a, b);
+}
+INTGEMM_AVX2 static inline __m256i add_epi16(__m256i a, __m256i b) {
+ return _mm256_add_epi16(a, b);
+}
INTGEMM_AVX2 static inline __m256i add_epi32(__m256i first, __m256i second) {
return _mm256_add_epi32(first, second);
}
@@ -146,6 +169,9 @@ INTGEMM_AVX2 static inline __m256 add_ps(__m256 a, __m256 b) {
INTGEMM_AVX2 static inline __m256 and_ps(__m256 first, __m256 second) {
return _mm256_and_ps(first, second);
}
+INTGEMM_AVX2 static inline __m256i and_si(__m256i a, __m256i b) {
+ return _mm256_and_si256(a, b);
+}
INTGEMM_AVX2 static inline __m256 cvtepi32_ps(__m256i arg) {
return _mm256_cvtepi32_ps(arg);
}
@@ -192,6 +218,9 @@ INTGEMM_AVX2 static inline __m256d mul_pd(__m256d a, __m256d b) {
INTGEMM_AVX2 static inline __m256 mul_ps(__m256 a, __m256 b) {
return _mm256_mul_ps(a, b);
}
+template <> INTGEMM_AVX2 inline __m256i set1_epi8<__m256i>(int8_t to) {
+ return _mm256_set1_epi8(to);
+}
template <> INTGEMM_AVX2 inline __m256i set1_epi16<__m256i>(int16_t to) {
return _mm256_set1_epi16(to);
}
@@ -204,6 +233,9 @@ template <> INTGEMM_AVX2 inline __m256d set1_pd<__m256d>(double to) {
template <> INTGEMM_AVX2 inline __m256 set1_ps<__m256>(float to) {
return _mm256_set1_ps(to);
}
+template <> INTGEMM_AVX2 inline __m256d setzero_pd<__m256d>() {
+ return _mm256_setzero_pd();
+}
template <> INTGEMM_AVX2 inline __m256 setzero_ps<__m256>() {
return _mm256_setzero_ps();
}
@@ -233,6 +265,12 @@ INTGEMM_AVX2 static inline __m256 sub_ps(__m256 a, __m256 b) {
INTGEMM_AVX512BW static inline __m512i abs_epi8(__m512i arg) {
return _mm512_abs_epi8(arg);
}
+INTGEMM_AVX512BW static inline __m512i add_epi8(__m512i a, __m512i b) {
+ return _mm512_add_epi8(a, b);
+}
+INTGEMM_AVX512BW static inline __m512i add_epi16(__m512i a, __m512i b) {
+ return _mm512_add_epi16(a, b);
+}
INTGEMM_AVX512BW static inline __m512i add_epi32(__m512i first, __m512i second) {
return _mm512_add_epi32(first, second);
}
@@ -248,6 +286,9 @@ INTGEMM_AVX512BW static inline __m512 add_ps(__m512 a, __m512 b) {
INTGEMM_AVX512DQ static inline __m512 and_ps(__m512 first, __m512 second) {
return _mm512_and_ps(first, second);
}
+INTGEMM_AVX512BW static inline __m512i and_si(__m512i a, __m512i b) {
+ return _mm512_and_si512(a, b);
+}
INTGEMM_AVX512BW static inline __m512 cvtepi32_ps(__m512i arg) {
return _mm512_cvtepi32_ps(arg);
}
@@ -294,6 +335,9 @@ INTGEMM_AVX512BW static inline __m512d mul_pd(__m512d a, __m512d b) {
INTGEMM_AVX512BW static inline __m512 mul_ps(__m512 a, __m512 b) {
return _mm512_mul_ps(a, b);
}
+template <> inline INTGEMM_AVX512BW __m512i set1_epi8<__m512i>(int8_t to) {
+ return _mm512_set1_epi8(to);
+}
template <> inline INTGEMM_AVX512BW __m512i set1_epi16<__m512i>(int16_t to) {
return _mm512_set1_epi16(to);
}
@@ -306,6 +350,9 @@ template <> inline INTGEMM_AVX512BW __m512d set1_pd<__m512d>(double to) {
template <> inline INTGEMM_AVX512BW __m512 set1_ps<__m512>(float to) {
return _mm512_set1_ps(to);
}
+template <> INTGEMM_AVX512BW inline __m512d setzero_pd<__m512d>() {
+ return _mm512_setzero_pd();
+}
template <> INTGEMM_AVX512BW inline __m512 setzero_ps<__m512>() {
return _mm512_setzero_ps();
}
diff --git a/kernels/implementations.inl b/kernels/implementations.inl
index 5de6a1e..fd46390 100644
--- a/kernels/implementations.inl
+++ b/kernels/implementations.inl
@@ -31,6 +31,14 @@ namespace kernels {
/*
* Write
*/
+CPU_ATTR static inline void write(vi input, int8_t* output, Index offset) {
+ *reinterpret_cast<vi*>(output + offset) = input;
+}
+
+CPU_ATTR static inline void write(vi input, int16_t* output, Index offset) {
+ *reinterpret_cast<vi*>(output + offset) = input;
+}
+
CPU_ATTR static inline void write(vi input, int* output, Index offset) {
*reinterpret_cast<vi*>(output + offset) = input;
}
@@ -60,18 +68,60 @@ CPU_ATTR static inline vf unquantize(vi input, vf unquant_mult) {
/*
* Add a bias term
*/
+CPU_ATTR static inline vi add_bias(vi input, const int8_t* bias_addr, Index bias_offset) {
+ auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset);
+ return add_epi8(input, bias_term);
+}
+
+CPU_ATTR static inline vi add_bias(vi input, const int16_t* bias_addr, Index bias_offset) {
+ auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset);
+ return add_epi16(input, bias_term);
+}
+
+CPU_ATTR static inline vi add_bias(vi input, const int* bias_addr, Index bias_offset) {
+ auto bias_term = *reinterpret_cast<const vi*>(bias_addr + bias_offset);
+ return add_epi32(input, bias_term);
+}
+
CPU_ATTR static inline vf add_bias(vf input, const float* bias_addr, Index bias_offset) {
auto bias_term = *reinterpret_cast<const vf*>(bias_addr + bias_offset);
return add_ps(input, bias_term);
}
+CPU_ATTR static inline vd add_bias(vd input, const double* bias_addr, Index bias_offset) {
+ auto bias_term = *reinterpret_cast<const vd*>(bias_addr + bias_offset);
+ return add_pd(input, bias_term);
+}
+
/*
* ReLU
*/
-CPU_ATTR static inline vi relu(vi input) {
+template <typename Type>
+CPU_ATTR static inline vector_t<CPUType::CPU_NAME, Type> relu(vector_t<CPUType::CPU_NAME, Type> input);
+
+template <>
+CPU_ATTR inline vi relu<int8_t>(vi input) {
+ static const auto vconst_zero = set1_epi8<vi>(0);
+#if defined(THIS_IS_SSE2)
+ return and_si(input, _mm_cmplt_epi8(vconst_zero, input));
+#elif defined(THIS_IS_AVX2)
+ return _mm256_max_epi8(input, vconst_zero);
+#else
+ return _mm512_max_epi8(input, vconst_zero);
+#endif
+}
+
+template <>
+CPU_ATTR inline vi relu<int16_t>(vi input) {
+ static const auto vconst_zero = set1_epi16<vi>(0);
+ return max_epi16(input, vconst_zero);
+}
+
+template <>
+CPU_ATTR inline vi relu<int>(vi input) {
static const auto vconst_zero = set1_epi32<vi>(0);
#if defined(THIS_IS_SSE2)
- return _mm_and_si128(input, _mm_cmplt_epi32(vconst_zero, input));
+ return and_si(input, _mm_cmplt_epi32(vconst_zero, input));
#elif defined(THIS_IS_AVX2)
return _mm256_max_epi32(input, vconst_zero);
#else
@@ -79,33 +129,22 @@ CPU_ATTR static inline vi relu(vi input) {
#endif
}
-CPU_ATTR static inline vf relu(vf input) {
- static const auto vconst_zero = set1_ps<vf>(0);
+template <>
+CPU_ATTR inline vf relu<float>(vf input) {
+ static const auto vconst_zero = setzero_ps<vf>();
return max_ps(input, vconst_zero);
}
-CPU_ATTR static inline vd relu(vd input) {
- static const auto vconst_zero = set1_pd<vd>(0);
+template <>
+CPU_ATTR inline vd relu<double>(vd input) {
+ static const auto vconst_zero = setzero_pd<vd>();
return max_pd(input, vconst_zero);
}
/*
- * Highway: weight * input1 + ([1] - weight) * input2, [0] <= weight <= [1]
- */
-CPU_ATTR static inline vf highway(vf input1, vf input2, vf weight) {
- static const auto vconst_one = set1_ps<vf>(1.f);
- return add_ps(mul_ps(input1, weight), mul_ps(input2, sub_ps(vconst_one, weight)));
-}
-
-CPU_ATTR static inline vd highway(vd input1, vd input2, vd weight) {
- static const auto vconst_one = set1_pd<vd>(1.f);
- return add_pd(mul_pd(input1, weight), mul_pd(input2, sub_pd(vconst_one, weight)));
-}
-
-/*
- * Calculate floor: float -> float
+ * Floor
*/
-CPU_ATTR static inline vf floor_ff(vf input) {
+CPU_ATTR static inline vf floor(vf input) {
#if defined(THIS_IS_SSE2)
static const auto vconst_zero = setzero_ps<vf>();
static const auto vconst_one = set1_ps<vf>(1.f);
@@ -167,7 +206,7 @@ CPU_ATTR static inline vf exp_approx_taylor(vf x) {
x = max_ps(x, const_min_x);
x = min_ps(x, const_max_x);
- auto a = floor_ff(x);
+ auto a = floor(x);
auto xa = sub_ps(x, a);
auto result = mul_ps(dividers[0], xa);
diff --git a/test/kernels/add_bias_test.cc b/test/kernels/add_bias_test.cc
index a873b4f..3c4a593 100644
--- a/test/kernels/add_bias_test.cc
+++ b/test/kernels/add_bias_test.cc
@@ -23,18 +23,42 @@ void kernel_add_bias_test() {
*output.template as<vec_t>() = kernels::add_bias(*input.template as<vec_t>(), bias.begin(), 0);
for (auto i = 0; i < output.size(); ++i)
- CHECK(output[i] == 100 + i);
+ CHECK(output[i] == ElemType_(100 + i));
}
+template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, int8_t>();
+template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, int16_t>();
+template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, int>();
template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, float>();
+template INTGEMM_SSE2 void kernel_add_bias_test<CPUType::SSE2, double>();
+KERNEL_TEST_CASE("add_bias/int8 SSE2") { return kernel_add_bias_test<CPUType::SSE2, int8_t>(); }
+KERNEL_TEST_CASE("add_bias/int16 SSE2") { return kernel_add_bias_test<CPUType::SSE2, int16_t>(); }
+KERNEL_TEST_CASE("add_bias/int SSE2") { return kernel_add_bias_test<CPUType::SSE2, int>(); }
KERNEL_TEST_CASE("add_bias/float SSE2") { return kernel_add_bias_test<CPUType::SSE2, float>(); }
+KERNEL_TEST_CASE("add_bias/double SSE2") { return kernel_add_bias_test<CPUType::SSE2, double>(); }
+template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int8_t>();
+template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int16_t>();
+template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int>();
template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, float>();
+template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, double>();
+KERNEL_TEST_CASE("add_bias/int8 AVX2") { return kernel_add_bias_test<CPUType::AVX2, int8_t>(); }
+KERNEL_TEST_CASE("add_bias/int16 AVX2") { return kernel_add_bias_test<CPUType::AVX2, int16_t>(); }
+KERNEL_TEST_CASE("add_bias/int AVX2") { return kernel_add_bias_test<CPUType::AVX2, int>(); }
KERNEL_TEST_CASE("add_bias/float AVX2") { return kernel_add_bias_test<CPUType::AVX2, float>(); }
+KERNEL_TEST_CASE("add_bias/double AVX2") { return kernel_add_bias_test<CPUType::AVX2, double>(); }
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, int8_t>();
+template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, int16_t>();
+template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, int>();
template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, float>();
+template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, double>();
+KERNEL_TEST_CASE("add_bias/int8 AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, int8_t>(); }
+KERNEL_TEST_CASE("add_bias/int16 AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, int16_t>(); }
+KERNEL_TEST_CASE("add_bias/int AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, int>(); }
KERNEL_TEST_CASE("add_bias/float AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, float>(); }
+KERNEL_TEST_CASE("add_bias/double AVX512BW") { return kernel_add_bias_test<CPUType::AVX512BW, double>(); }
#endif
}
diff --git a/test/kernels/floor_ff_test.cc b/test/kernels/floor_test.cc
index 0f36229..8f21af3 100644
--- a/test/kernels/floor_ff_test.cc
+++ b/test/kernels/floor_test.cc
@@ -7,7 +7,7 @@
namespace intgemm {
template <CPUType CPUType_>
-void kernel_floor_ff_test() {
+void kernel_floor_test() {
if (kCPU < CPUType_)
return;
@@ -19,20 +19,20 @@ void kernel_floor_ff_test() {
std::iota(input.begin(), input.end(), -int(VECTOR_LENGTH / 2));
- *output.template as<vec_t>() = kernels::floor_ff(*input.template as<vec_t>());
+ *output.template as<vec_t>() = kernels::floor(*input.template as<vec_t>());
for (auto i = 0; i < output.size(); ++i)
CHECK(output[i] == std::floor(input[i]));
}
-template INTGEMM_SSE2 void kernel_floor_ff_test<CPUType::SSE2>();
-KERNEL_TEST_CASE("floor_ff SSE2") { return kernel_floor_ff_test<CPUType::SSE2>(); }
+template INTGEMM_SSE2 void kernel_floor_test<CPUType::SSE2>();
+KERNEL_TEST_CASE("floor SSE2") { return kernel_floor_test<CPUType::SSE2>(); }
-template INTGEMM_AVX2 void kernel_floor_ff_test<CPUType::AVX2>();
-KERNEL_TEST_CASE("floor_ff AVX2") { return kernel_floor_ff_test<CPUType::AVX2>(); }
+template INTGEMM_AVX2 void kernel_floor_test<CPUType::AVX2>();
+KERNEL_TEST_CASE("floor AVX2") { return kernel_floor_test<CPUType::AVX2>(); }
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
-template INTGEMM_AVX512BW void kernel_floor_ff_test<CPUType::AVX512BW>();
-KERNEL_TEST_CASE("floor_ff AVX512BW") { return kernel_floor_ff_test<CPUType::AVX512BW>(); }
+template INTGEMM_AVX512BW void kernel_floor_test<CPUType::AVX512BW>();
+KERNEL_TEST_CASE("floor AVX512BW") { return kernel_floor_test<CPUType::AVX512BW>(); }
#endif
}
diff --git a/test/kernels/highway_test.cc b/test/kernels/highway_test.cc
deleted file mode 100644
index 31b9737..0000000
--- a/test/kernels/highway_test.cc
+++ /dev/null
@@ -1,48 +0,0 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
-
-#include <numeric>
-
-namespace intgemm {
-
-template <CPUType CPUType_, typename ElemType_>
-void kernel_highway_test() {
- if (kCPU < CPUType_)
- return;
-
- using vec_t = vector_t<CPUType_, ElemType_>;
- constexpr static auto VECTOR_LENGTH = sizeof(vec_t) / sizeof(ElemType_);
-
- AlignedVector<ElemType_> input1(VECTOR_LENGTH);
- AlignedVector<ElemType_> input2(VECTOR_LENGTH);
- AlignedVector<ElemType_> weight(VECTOR_LENGTH);
- AlignedVector<ElemType_> output(VECTOR_LENGTH);
-
- std::iota(input1.begin(), input1.end(), 0);
- std::iota(input2.begin(), input2.end(), 100);
- std::fill(weight.begin(), weight.end(), 0.1);
-
- *output.template as<vec_t>() = kernels::highway(*input1.template as<vec_t>(), *input2.template as<vec_t>(), *weight.template as<vec_t>());
- for (auto i = 0; i < output.size(); ++i)
- CHECK_EPS(output[i], input1[i] * weight[0] + input2[i] * (1 - weight[0]), 0.00001);
-}
-
-template INTGEMM_SSE2 void kernel_highway_test<CPUType::SSE2, float>();
-template INTGEMM_SSE2 void kernel_highway_test<CPUType::SSE2, double>();
-KERNEL_TEST_CASE("highway/float SSE2") { return kernel_highway_test<CPUType::SSE2, float>(); }
-KERNEL_TEST_CASE("highway/double SSE2") { return kernel_highway_test<CPUType::SSE2, double>(); }
-
-template INTGEMM_AVX2 void kernel_highway_test<CPUType::AVX2, float>();
-template INTGEMM_AVX2 void kernel_highway_test<CPUType::AVX2, double>();
-KERNEL_TEST_CASE("highway/float AVX2") { return kernel_highway_test<CPUType::AVX2, float>(); }
-KERNEL_TEST_CASE("highway/double AVX2") { return kernel_highway_test<CPUType::AVX2, double>(); }
-
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
-template INTGEMM_AVX512BW void kernel_highway_test<CPUType::AVX512BW, float>();
-template INTGEMM_AVX512BW void kernel_highway_test<CPUType::AVX512BW, double>();
-KERNEL_TEST_CASE("highway/float AVX512BW") { return kernel_highway_test<CPUType::AVX512BW, float>(); }
-KERNEL_TEST_CASE("highway/double AVX512BW") { return kernel_highway_test<CPUType::AVX512BW, double>(); }
-#endif
-
-}
diff --git a/test/kernels/relu_test.cc b/test/kernels/relu_test.cc
index f9fe684..7631623 100644
--- a/test/kernels/relu_test.cc
+++ b/test/kernels/relu_test.cc
@@ -19,29 +19,41 @@ void kernel_relu_test() {
std::iota(input.begin(), input.end(), -int(VECTOR_LENGTH / 2));
- *output.template as<vec_t>() = kernels::relu(*input.template as<vec_t>());
+ *output.template as<vec_t>() = kernels::relu<ElemType_>(*input.template as<vec_t>());
for (auto i = 0; i < output.size(); ++i)
CHECK(output[i] == (input[i] < 0 ? 0 : input[i]));
}
+template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int8_t>();
+template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int16_t>();
template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int>();
template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, float>();
template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, double>();
+KERNEL_TEST_CASE("relu/int8 SSE2") { return kernel_relu_test<CPUType::SSE2, int8_t>(); }
+KERNEL_TEST_CASE("relu/int16 SSE2") { return kernel_relu_test<CPUType::SSE2, int16_t>(); }
KERNEL_TEST_CASE("relu/int SSE2") { return kernel_relu_test<CPUType::SSE2, int>(); }
KERNEL_TEST_CASE("relu/float SSE2") { return kernel_relu_test<CPUType::SSE2, float>(); }
KERNEL_TEST_CASE("relu/double SSE2") { return kernel_relu_test<CPUType::SSE2, double>(); }
+template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int8_t>();
+template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int16_t>();
template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int>();
template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, float>();
template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, double>();
+KERNEL_TEST_CASE("relu/int8 AVX2") { return kernel_relu_test<CPUType::AVX2, int8_t>(); }
+KERNEL_TEST_CASE("relu/int16 AVX2") { return kernel_relu_test<CPUType::AVX2, int16_t>(); }
KERNEL_TEST_CASE("relu/int AVX2") { return kernel_relu_test<CPUType::AVX2, int>(); }
KERNEL_TEST_CASE("relu/float AVX2") { return kernel_relu_test<CPUType::AVX2, float>(); }
KERNEL_TEST_CASE("relu/double AVX2") { return kernel_relu_test<CPUType::AVX2, double>(); }
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int8_t>();
+template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int16_t>();
template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int>();
template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, float>();
template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, double>();
+KERNEL_TEST_CASE("relu/int8 AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int8_t>(); }
+KERNEL_TEST_CASE("relu/int16 AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int16_t>(); }
KERNEL_TEST_CASE("relu/int AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int>(); }
KERNEL_TEST_CASE("relu/float AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, float>(); }
KERNEL_TEST_CASE("relu/double AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, double>(); }
diff --git a/test/kernels/write_test.cc b/test/kernels/write_test.cc
index a1834df..8d85600 100644
--- a/test/kernels/write_test.cc
+++ b/test/kernels/write_test.cc
@@ -21,27 +21,39 @@ void kernel_write_test() {
kernels::write(*input.template as<vec_t>(), output.begin(), 0);
for (auto i = 0; i < VECTOR_LENGTH; ++i)
- CHECK(output[i] == i);
+ CHECK(output[i] == ElemType_(i));
}
+template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, int8_t>();
+template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, int16_t>();
template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, int>();
template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, float>();
template INTGEMM_SSE2 void kernel_write_test<CPUType::SSE2, double>();
+KERNEL_TEST_CASE("write/int8 SSE2") { return kernel_write_test<CPUType::SSE2, int8_t>(); }
+KERNEL_TEST_CASE("write/int16 SSE2") { return kernel_write_test<CPUType::SSE2, int16_t>(); }
KERNEL_TEST_CASE("write/int SSE2") { return kernel_write_test<CPUType::SSE2, int>(); }
KERNEL_TEST_CASE("write/float SSE2") { return kernel_write_test<CPUType::SSE2, float>(); }
KERNEL_TEST_CASE("write/double SSE2") { return kernel_write_test<CPUType::SSE2, double>(); }
+template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int8_t>();
+template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int16_t>();
template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int>();
template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, float>();
template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, double>();
+KERNEL_TEST_CASE("write/int8 AVX2") { return kernel_write_test<CPUType::AVX2, int8_t>(); }
+KERNEL_TEST_CASE("write/int16 AVX2") { return kernel_write_test<CPUType::AVX2, int16_t>(); }
KERNEL_TEST_CASE("write/int AVX2") { return kernel_write_test<CPUType::AVX2, int>(); }
KERNEL_TEST_CASE("write/float AVX2") { return kernel_write_test<CPUType::AVX2, float>(); }
KERNEL_TEST_CASE("write/double AVX2") { return kernel_write_test<CPUType::AVX2, double>(); }
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, int8_t>();
+template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, int16_t>();
template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, int>();
template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, float>();
template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, double>();
+KERNEL_TEST_CASE("write/int8 AVX512BW") { return kernel_write_test<CPUType::AVX512BW, int8_t>(); }
+KERNEL_TEST_CASE("write/int16 AVX512BW") { return kernel_write_test<CPUType::AVX512BW, int16_t>(); }
KERNEL_TEST_CASE("write/int AVX512BW") { return kernel_write_test<CPUType::AVX512BW, int>(); }
KERNEL_TEST_CASE("write/float AVX512BW") { return kernel_write_test<CPUType::AVX512BW, float>(); }
KERNEL_TEST_CASE("write/double AVX512BW") { return kernel_write_test<CPUType::AVX512BW, double>(); }
diff --git a/vec_traits.h b/vec_traits.h
index 4bf369d..33514f9 100644
--- a/vec_traits.h
+++ b/vec_traits.h
@@ -8,12 +8,18 @@ namespace intgemm {
* Vector traits
*/
template <CPUType CPUType_, typename ElemType_> struct vector_s;
+template <> struct vector_s<CPUType::SSE2, int8_t> { using type = __m128i; };
+template <> struct vector_s<CPUType::SSE2, int16_t> { using type = __m128i; };
template <> struct vector_s<CPUType::SSE2, int> { using type = __m128i; };
template <> struct vector_s<CPUType::SSE2, float> { using type = __m128; };
template <> struct vector_s<CPUType::SSE2, double> { using type = __m128d; };
+template <> struct vector_s<CPUType::AVX2, int8_t> { using type = __m256i; };
+template <> struct vector_s<CPUType::AVX2, int16_t> { using type = __m256i; };
template <> struct vector_s<CPUType::AVX2, int> { using type = __m256i; };
template <> struct vector_s<CPUType::AVX2, float> { using type = __m256; };
template <> struct vector_s<CPUType::AVX2, double> { using type = __m256d; };
+template <> struct vector_s<CPUType::AVX512BW, int8_t> { using type = __m512i; };
+template <> struct vector_s<CPUType::AVX512BW, int16_t> { using type = __m512i; };
template <> struct vector_s<CPUType::AVX512BW, int> { using type = __m512i; };
template <> struct vector_s<CPUType::AVX512BW, float> { using type = __m512; };
template <> struct vector_s<CPUType::AVX512BW, double> { using type = __m512d; };