diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-08 17:51:41 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-09 22:39:21 +0300 |
commit | 7e514a4d1178ddeae0cf38fa29f5ca758abf8a9a (patch) | |
tree | 03e7da533b58b66a8f40dea21915ad9a66ee6e9b | |
parent | 5466238858becaec459d154137dbd2d79baa0d3d (diff) |
Add simple vector traits
-rw-r--r-- | multiply.h | 15 | ||||
-rw-r--r-- | postprocess.h | 25 | ||||
-rw-r--r-- | test/postprocess/add_bias_test.cc | 4 | ||||
-rw-r--r-- | test/postprocess/relu_test.cc | 4 | ||||
-rw-r--r-- | test/postprocess/unquantize_test.cc | 4 | ||||
-rw-r--r-- | vec_traits.h | 32 | ||||
-rw-r--r-- | vec_utils.h | 8 |
7 files changed, 59 insertions, 33 deletions
@@ -4,12 +4,13 @@ #include "intrinsics.h" #include "postprocess_pipeline.h" #include "vec_utils.h" +#include "vec_traits.h" namespace intgemm { -INTGEMM_SSE2 static inline void writer(float* C, Index offset, RegisterPair128 result) { - *reinterpret_cast<__m128*>(C + offset) = result.pack0123; - *reinterpret_cast<__m128*>(C + offset + 4) = result.pack4567; +INTGEMM_SSE2 static inline void writer(float* C, Index offset, dvector_t<CPUType::SSE2, float> result) { + *reinterpret_cast<__m128*>(C + offset) = result.first; + *reinterpret_cast<__m128*>(C + offset + 4) = result.second; } INTGEMM_AVX2 static inline void writer(float* C, Index offset, __m256 result) { @@ -33,11 +34,11 @@ INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) { return *reinterpret_cast<float*>(&a); } -INTGEMM_SSE2 static inline RegisterPair128i PermuteSummer(__m128i pack0123, __m128i pack4567) { +INTGEMM_SSE2 static inline dvector_t<CPUType::SSE2, int> PermuteSummer(__m128i pack0123, __m128i pack4567) { // No op for 128 bits: already reduced fully. - RegisterPair128i ret; - ret.pack0123 = pack0123; - ret.pack4567 = pack4567; + dvector_t<CPUType::SSE2, int> ret; + ret.first = pack0123; + ret.second = pack4567; return ret; } diff --git a/postprocess.h b/postprocess.h index 7835b2b..53c5a3e 100644 --- a/postprocess.h +++ b/postprocess.h @@ -4,6 +4,7 @@ #include "postprocess_pipeline.h" #include "types.h" #include "vec_utils.h" +#include "vec_traits.h" // TODO: We support some postprocess in few variations e.g. we support ReLU for // float -> float, int8 -> int8, int16 -> int16. Maybe it would be a good idea @@ -24,8 +25,8 @@ public: template <> class PostprocessImpl<Unquantize, CPUType::SSE2> { public: - using InputRegister = RegisterPair128i; - using OutputRegister = RegisterPair128; + using InputRegister = dvector_t<CPUType::SSE2, int>; + using OutputRegister = dvector_t<CPUType::SSE2, float>; INTGEMM_SSE2 PostprocessImpl(const Unquantize& config) { unquantize_multiplier = set1_ps<__m128>(config.unquantize_multiplier); @@ -33,8 +34,8 @@ public: INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { return { - mul_ps(cvtepi32_ps(input.pack0123), unquantize_multiplier), - mul_ps(cvtepi32_ps(input.pack4567), unquantize_multiplier), + mul_ps(cvtepi32_ps(input.first), unquantize_multiplier), + mul_ps(cvtepi32_ps(input.second), unquantize_multiplier), }; } @@ -96,8 +97,8 @@ public: template <> class PostprocessImpl<AddBias, CPUType::SSE2> { public: - using InputRegister = RegisterPair128; - using OutputRegister = RegisterPair128; + using InputRegister = dvector_t<CPUType::SSE2, float>; + using OutputRegister = dvector_t<CPUType::SSE2, float>; PostprocessImpl(const AddBias& config) : config(config) {} @@ -105,8 +106,8 @@ public: auto bias_term0123 = *reinterpret_cast<const __m128*>(config.bias + (offset % config.length)); auto bias_term4567 = *reinterpret_cast<const __m128*>(config.bias + (offset % config.length) + 4); return { - add_ps(input.pack0123, bias_term0123), - add_ps(input.pack4567, bias_term4567), + add_ps(input.first, bias_term0123), + add_ps(input.second, bias_term4567), }; } @@ -160,16 +161,16 @@ class ReLU {}; template <> class PostprocessImpl<ReLU, CPUType::SSE2> { public: - using InputRegister = RegisterPair128; - using OutputRegister = RegisterPair128; + using InputRegister = dvector_t<CPUType::SSE2, float>; + using OutputRegister = dvector_t<CPUType::SSE2, float>; PostprocessImpl(const ReLU& config) {} INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) { static const auto const_zero = set1_ps<__m128>(0.f); return { - max_ps(const_zero, input.pack0123), - max_ps(const_zero, input.pack4567), + max_ps(const_zero, input.first), + max_ps(const_zero, input.second), }; } }; diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc index 5e893ea..3bc7f74 100644 --- a/test/postprocess/add_bias_test.cc +++ b/test/postprocess/add_bias_test.cc @@ -19,8 +19,8 @@ INTGEMM_SSE2 TEST_CASE("AddBias SSE2",) { auto postproc = PostprocessImpl<AddBias, CPUType::SSE2>(AddBias(bias.begin(), bias.size())); auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0); - output.as<__m128>()[0] = output_tmp.pack0123; - output.as<__m128>()[1] = output_tmp.pack4567; + output.as<__m128>()[0] = output_tmp.first; + output.as<__m128>()[1] = output_tmp.second; CHECK(output[0] == -2.f); // input = -2, bias = 0 CHECK(output[1] == 0.f); // input = -1, bias = 1 diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc index af6677e..a560790 100644 --- a/test/postprocess/relu_test.cc +++ b/test/postprocess/relu_test.cc @@ -19,8 +19,8 @@ INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) { auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU()); auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0); - output.as<__m128>()[0] = output_tmp.pack0123; - output.as<__m128>()[1] = output_tmp.pack4567; + output.as<__m128>()[0] = output_tmp.first; + output.as<__m128>()[1] = output_tmp.second; CHECK(output[0] == 0.f); // input = -2 CHECK(output[1] == 0.f); // input = -1 diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc index c33b909..45e6bc4 100644 --- a/test/postprocess/unquantize_test.cc +++ b/test/postprocess/unquantize_test.cc @@ -16,8 +16,8 @@ INTGEMM_SSE2 TEST_CASE("Unquantize SSE2",) { auto postproc = PostprocessImpl<Unquantize, CPUType::SSE2>(Unquantize(0.5f)); auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0); - output.as<__m128>()[0] = output_tmp.pack0123; - output.as<__m128>()[1] = output_tmp.pack4567; + output.as<__m128>()[0] = output_tmp.first; + output.as<__m128>()[1] = output_tmp.second; CHECK(output[0] == -1.0f); // input = -2 CHECK(output[1] == -0.5f); // input = -1 diff --git a/vec_traits.h b/vec_traits.h new file mode 100644 index 0000000..4bf369d --- /dev/null +++ b/vec_traits.h @@ -0,0 +1,32 @@ +#pragma once + +#include "types.h" + +namespace intgemm { + +/* + * Vector traits + */ +template <CPUType CPUType_, typename ElemType_> struct vector_s; +template <> struct vector_s<CPUType::SSE2, int> { using type = __m128i; }; +template <> struct vector_s<CPUType::SSE2, float> { using type = __m128; }; +template <> struct vector_s<CPUType::SSE2, double> { using type = __m128d; }; +template <> struct vector_s<CPUType::AVX2, int> { using type = __m256i; }; +template <> struct vector_s<CPUType::AVX2, float> { using type = __m256; }; +template <> struct vector_s<CPUType::AVX2, double> { using type = __m256d; }; +template <> struct vector_s<CPUType::AVX512BW, int> { using type = __m512i; }; +template <> struct vector_s<CPUType::AVX512BW, float> { using type = __m512; }; +template <> struct vector_s<CPUType::AVX512BW, double> { using type = __m512d; }; + +template <CPUType CPUType_, typename ElemType_> +using vector_t = typename vector_s<CPUType_, ElemType_>::type; + +template <CPUType CPUType_, typename ElemType_> +struct dvector_t { + using type = vector_t<CPUType_, ElemType_>; + + type first; + type second; +}; + +} diff --git a/vec_utils.h b/vec_utils.h index acb7d6e..a5f1469 100644 --- a/vec_utils.h +++ b/vec_utils.h @@ -4,14 +4,6 @@ namespace intgemm { -struct RegisterPair128i { - __m128i pack0123, pack4567; -}; - -struct RegisterPair128 { - __m128 pack0123, pack4567; -}; - /* * * Quantize |