Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-08 17:51:41 +0300
committer Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-09 22:39:21 +0300
commit 7e514a4d1178ddeae0cf38fa29f5ca758abf8a9a (patch)
tree 03e7da533b58b66a8f40dea21915ad9a66ee6e9b
parent 5466238858becaec459d154137dbd2d79baa0d3d (diff)
Add simple vector traits
-rw-r--r-- multiply.h 15
-rw-r--r-- postprocess.h 25
-rw-r--r-- test/postprocess/add_bias_test.cc 4
-rw-r--r-- test/postprocess/relu_test.cc 4
-rw-r--r-- test/postprocess/unquantize_test.cc 4
-rw-r--r-- vec_traits.h 32
-rw-r--r-- vec_utils.h 8
7 files changed, 59 insertions, 33 deletions
diff --git a/multiply.h b/multiply.h
index 420e815..dff06f3 100644
--- a/multiply.h
+++ b/multiply.h
@@ -4,12 +4,13 @@
#include "intrinsics.h"
#include "postprocess_pipeline.h"
#include "vec_utils.h"
+#include "vec_traits.h"
namespace intgemm {
-INTGEMM_SSE2 static inline void writer(float* C, Index offset, RegisterPair128 result) {
- *reinterpret_cast<__m128*>(C + offset) = result.pack0123;
- *reinterpret_cast<__m128*>(C + offset + 4) = result.pack4567;
+INTGEMM_SSE2 static inline void writer(float* C, Index offset, dvector_t<CPUType::SSE2, float> result) {
+ *reinterpret_cast<__m128*>(C + offset) = result.first;
+ *reinterpret_cast<__m128*>(C + offset + 4) = result.second;
}
INTGEMM_AVX2 static inline void writer(float* C, Index offset, __m256 result) {
@@ -33,11 +34,11 @@ INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
return *reinterpret_cast<float*>(&a);
}
-INTGEMM_SSE2 static inline RegisterPair128i PermuteSummer(__m128i pack0123, __m128i pack4567) {
+INTGEMM_SSE2 static inline dvector_t<CPUType::SSE2, int> PermuteSummer(__m128i pack0123, __m128i pack4567) {
// No op for 128 bits: already reduced fully.
- RegisterPair128i ret;
- ret.pack0123 = pack0123;
- ret.pack4567 = pack4567;
+ dvector_t<CPUType::SSE2, int> ret;
+ ret.first = pack0123;
+ ret.second = pack4567;
return ret;
}
diff --git a/postprocess.h b/postprocess.h
index 7835b2b..53c5a3e 100644
--- a/postprocess.h
+++ b/postprocess.h
@@ -4,6 +4,7 @@
#include "postprocess_pipeline.h"
#include "types.h"
#include "vec_utils.h"
+#include "vec_traits.h"
// TODO: We support some postprocess in few variations e.g. we support ReLU for
// float -> float, int8 -> int8, int16 -> int16. Maybe it would be a good idea
@@ -24,8 +25,8 @@ public:
template <>
class PostprocessImpl<Unquantize, CPUType::SSE2> {
public:
- using InputRegister = RegisterPair128i;
- using OutputRegister = RegisterPair128;
+ using InputRegister = dvector_t<CPUType::SSE2, int>;
+ using OutputRegister = dvector_t<CPUType::SSE2, float>;
INTGEMM_SSE2 PostprocessImpl(const Unquantize& config) {
unquantize_multiplier = set1_ps<__m128>(config.unquantize_multiplier);
@@ -33,8 +34,8 @@ public:
INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
return {
- mul_ps(cvtepi32_ps(input.pack0123), unquantize_multiplier),
- mul_ps(cvtepi32_ps(input.pack4567), unquantize_multiplier),
+ mul_ps(cvtepi32_ps(input.first), unquantize_multiplier),
+ mul_ps(cvtepi32_ps(input.second), unquantize_multiplier),
};
}
@@ -96,8 +97,8 @@ public:
template <>
class PostprocessImpl<AddBias, CPUType::SSE2> {
public:
- using InputRegister = RegisterPair128;
- using OutputRegister = RegisterPair128;
+ using InputRegister = dvector_t<CPUType::SSE2, float>;
+ using OutputRegister = dvector_t<CPUType::SSE2, float>;
PostprocessImpl(const AddBias& config) : config(config) {}
@@ -105,8 +106,8 @@ public:
auto bias_term0123 = *reinterpret_cast<const __m128*>(config.bias + (offset % config.length));
auto bias_term4567 = *reinterpret_cast<const __m128*>(config.bias + (offset % config.length) + 4);
return {
- add_ps(input.pack0123, bias_term0123),
- add_ps(input.pack4567, bias_term4567),
+ add_ps(input.first, bias_term0123),
+ add_ps(input.second, bias_term4567),
};
}
@@ -160,16 +161,16 @@ class ReLU {};
template <>
class PostprocessImpl<ReLU, CPUType::SSE2> {
public:
- using InputRegister = RegisterPair128;
- using OutputRegister = RegisterPair128;
+ using InputRegister = dvector_t<CPUType::SSE2, float>;
+ using OutputRegister = dvector_t<CPUType::SSE2, float>;
PostprocessImpl(const ReLU& config) {}
INTGEMM_SSE2 inline OutputRegister run(InputRegister input, Index offset) {
static const auto const_zero = set1_ps<__m128>(0.f);
return {
- max_ps(const_zero, input.pack0123),
- max_ps(const_zero, input.pack4567),
+ max_ps(const_zero, input.first),
+ max_ps(const_zero, input.second),
};
}
};
diff --git a/test/postprocess/add_bias_test.cc b/test/postprocess/add_bias_test.cc
index 5e893ea..3bc7f74 100644
--- a/test/postprocess/add_bias_test.cc
+++ b/test/postprocess/add_bias_test.cc
@@ -19,8 +19,8 @@ INTGEMM_SSE2 TEST_CASE("AddBias SSE2",) {
auto postproc = PostprocessImpl<AddBias, CPUType::SSE2>(AddBias(bias.begin(), bias.size()));
auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0);
- output.as<__m128>()[0] = output_tmp.pack0123;
- output.as<__m128>()[1] = output_tmp.pack4567;
+ output.as<__m128>()[0] = output_tmp.first;
+ output.as<__m128>()[1] = output_tmp.second;
CHECK(output[0] == -2.f); // input = -2, bias = 0
CHECK(output[1] == 0.f); // input = -1, bias = 1
diff --git a/test/postprocess/relu_test.cc b/test/postprocess/relu_test.cc
index af6677e..a560790 100644
--- a/test/postprocess/relu_test.cc
+++ b/test/postprocess/relu_test.cc
@@ -19,8 +19,8 @@ INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU());
auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0);
- output.as<__m128>()[0] = output_tmp.pack0123;
- output.as<__m128>()[1] = output_tmp.pack4567;
+ output.as<__m128>()[0] = output_tmp.first;
+ output.as<__m128>()[1] = output_tmp.second;
CHECK(output[0] == 0.f); // input = -2
CHECK(output[1] == 0.f); // input = -1
diff --git a/test/postprocess/unquantize_test.cc b/test/postprocess/unquantize_test.cc
index c33b909..45e6bc4 100644
--- a/test/postprocess/unquantize_test.cc
+++ b/test/postprocess/unquantize_test.cc
@@ -16,8 +16,8 @@ INTGEMM_SSE2 TEST_CASE("Unquantize SSE2",) {
auto postproc = PostprocessImpl<Unquantize, CPUType::SSE2>(Unquantize(0.5f));
auto output_tmp = postproc.run({input.as<__m128i>()[0], input.as<__m128i>()[1]}, 0);
- output.as<__m128>()[0] = output_tmp.pack0123;
- output.as<__m128>()[1] = output_tmp.pack4567;
+ output.as<__m128>()[0] = output_tmp.first;
+ output.as<__m128>()[1] = output_tmp.second;
CHECK(output[0] == -1.0f); // input = -2
CHECK(output[1] == -0.5f); // input = -1
diff --git a/vec_traits.h b/vec_traits.h
new file mode 100644
index 0000000..4bf369d
--- /dev/null
+++ b/vec_traits.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include "types.h"
+
+namespace intgemm {
+
+/*
+ * Vector traits
+ */
+template <CPUType CPUType_, typename ElemType_> struct vector_s;
+template <> struct vector_s<CPUType::SSE2, int> { using type = __m128i; };
+template <> struct vector_s<CPUType::SSE2, float> { using type = __m128; };
+template <> struct vector_s<CPUType::SSE2, double> { using type = __m128d; };
+template <> struct vector_s<CPUType::AVX2, int> { using type = __m256i; };
+template <> struct vector_s<CPUType::AVX2, float> { using type = __m256; };
+template <> struct vector_s<CPUType::AVX2, double> { using type = __m256d; };
+template <> struct vector_s<CPUType::AVX512BW, int> { using type = __m512i; };
+template <> struct vector_s<CPUType::AVX512BW, float> { using type = __m512; };
+template <> struct vector_s<CPUType::AVX512BW, double> { using type = __m512d; };
+
+template <CPUType CPUType_, typename ElemType_>
+using vector_t = typename vector_s<CPUType_, ElemType_>::type;
+
+template <CPUType CPUType_, typename ElemType_>
+struct dvector_t {
+ using type = vector_t<CPUType_, ElemType_>;
+
+ type first;
+ type second;
+};
+
+}
diff --git a/vec_utils.h b/vec_utils.h
index acb7d6e..a5f1469 100644
--- a/vec_utils.h
+++ b/vec_utils.h
@@ -4,14 +4,6 @@
namespace intgemm {
-struct RegisterPair128i {
- __m128i pack0123, pack4567;
-};
-
-struct RegisterPair128 {
- __m128 pack0123, pack4567;
-};
-
/*
*
* Quantize