Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--aligned.h3
-rw-r--r--test/multiply_test.cc4
-rw-r--r--test/pipeline_test.cc77
-rw-r--r--test/relu_test.cc105
-rw-r--r--test/sigmoid_test.cc29
-rw-r--r--test/tanh_test.cc31
6 files changed, 121 insertions, 128 deletions
diff --git a/aligned.h b/aligned.h
index 7514000..6795788 100644
--- a/aligned.h
+++ b/aligned.h
@@ -22,6 +22,9 @@ template <class T> class AlignedVector {
T *end() { return mem_ + size_; }
const T *end() const { return mem_ + size_; }
+ template <typename ReturnType>
+ ReturnType *as() { return reinterpret_cast<ReturnType*>(mem_); }
+
private:
T *mem_;
std::size_t size_;
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 82062fe..f88a73a 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -61,7 +61,7 @@ INTGEMM_SSE2 TEST_CASE("Transpose 16", "[transpose]") {
SlowTranspose(input.begin(), ref.begin(), N, N);
// Overwrite input.
- __m128i *t = reinterpret_cast<__m128i*>(input.begin());
+ __m128i *t = input.as<__m128i>();
Transpose16InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7]);
for (int16_t i = 0; i < input.size(); ++i) {
@@ -79,7 +79,7 @@ INTGEMM_SSSE3 TEST_CASE("Transpose 8", "[transpose]") {
SlowTranspose(input.begin(), ref.begin(), N, N);
// Overwrite input.
- __m128i *t = reinterpret_cast<__m128i*>(input.begin());
+ __m128i *t = input.as<__m128i>();
Transpose8InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8], t[9], t[10], t[11], t[12], t[13], t[14], t[15]);
for (int i = 0; i < input.size(); ++i) {
diff --git a/test/pipeline_test.cc b/test/pipeline_test.cc
index 1b8c21d..8d60cff 100644
--- a/test/pipeline_test.cc
+++ b/test/pipeline_test.cc
@@ -1,4 +1,5 @@
#include "3rd_party/catch.hpp"
+#include "aligned.h"
#include "postprocess.h"
#include <numeric>
@@ -9,62 +10,54 @@ INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") {
if (kCPU < CPUType::AVX2)
return;
- __m256i input;
- __m256 output;
+ AlignedVector<int32_t> input(8);
+ AlignedVector<float> output(8);
- auto raw_input = reinterpret_cast<int*>(&input);
- std::iota(raw_input, raw_input + 8, -2);
-
- auto raw_output = reinterpret_cast<float*>(&output);
- std::fill(raw_output, raw_output + 8, 42);
+ std::iota(input.begin(), input.end(), -2);
auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
- output = inited_pipeline.run(input, 0);
-
- CHECK(raw_output[0] == 0.0f); // input = -2
- CHECK(raw_output[1] == 0.0f); // input = -1
- CHECK(raw_output[2] == 0.0f); // input = 0
- CHECK(raw_output[3] == 0.5f); // input = 1
- CHECK(raw_output[4] == 1.0f); // input = 2
- CHECK(raw_output[5] == 1.5f); // input = 3
- CHECK(raw_output[6] == 2.0f); // input = 4
- CHECK(raw_output[7] == 2.5f); // input = 5
+ *output.as<__m256>() = inited_pipeline.run(*input.as<__m256i>(), 0);
+
+ CHECK(output[0] == 0.0f); // input = -2
+ CHECK(output[1] == 0.0f); // input = -1
+ CHECK(output[2] == 0.0f); // input = 0
+ CHECK(output[3] == 0.5f); // input = 1
+ CHECK(output[4] == 1.0f); // input = 2
+ CHECK(output[5] == 1.5f); // input = 3
+ CHECK(output[6] == 2.0f); // input = 4
+ CHECK(output[7] == 2.5f); // input = 5
}
INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-ReLU") {
if (kCPU < CPUType::AVX2)
return;
- __m256i input[2];
- __m256 output[2];
+ AlignedVector<int32_t> input(16);
+ AlignedVector<float> output(16);
- auto raw_input = reinterpret_cast<int*>(input);
- std::iota(raw_input, raw_input + 16, -8);
-
- auto raw_output = reinterpret_cast<float*>(output);
- std::fill(raw_output, raw_output + 16, 42);
+ std::iota(input.begin(), input.end(), -8);
auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
- inited_pipeline.run(input, 2, output);
-
- CHECK(raw_output[0] == 0.f); // input = -8
- CHECK(raw_output[1] == 0.f); // input = -7
- CHECK(raw_output[2] == 0.f); // input = -6
- CHECK(raw_output[3] == 0.f); // input = -5
- CHECK(raw_output[4] == 0.f); // input = -4
- CHECK(raw_output[5] == 0.f); // input = -3
- CHECK(raw_output[6] == 0.f); // input = -2
- CHECK(raw_output[7] == 0.f); // input = -1
- CHECK(raw_output[8] == 0.0f); // input = 0
- CHECK(raw_output[9] == 0.5f); // input = 1
- CHECK(raw_output[10] == 1.0f); // input = 2
- CHECK(raw_output[11] == 1.5f); // input = 3
- CHECK(raw_output[12] == 2.0f); // input = 4
- CHECK(raw_output[13] == 2.5f); // input = 5
- CHECK(raw_output[14] == 3.0f); // input = 6
- CHECK(raw_output[15] == 3.5f); // input = 7
+ inited_pipeline.run(input.as<__m256i>(), 2, output.as<__m256>());
+
+ CHECK(output[0] == 0.f); // input = -8
+ CHECK(output[1] == 0.f); // input = -7
+ CHECK(output[2] == 0.f); // input = -6
+ CHECK(output[3] == 0.f); // input = -5
+ CHECK(output[4] == 0.f); // input = -4
+ CHECK(output[5] == 0.f); // input = -3
+ CHECK(output[6] == 0.f); // input = -2
+ CHECK(output[7] == 0.f); // input = -1
+ CHECK(output[8] == 0.0f); // input = 0
+ CHECK(output[9] == 0.5f); // input = 1
+ CHECK(output[10] == 1.0f); // input = 2
+ CHECK(output[11] == 1.5f); // input = 3
+ CHECK(output[12] == 2.0f); // input = 4
+ CHECK(output[13] == 2.5f); // input = 5
+ CHECK(output[14] == 3.0f); // input = 6
+ CHECK(output[15] == 3.5f); // input = 7
}
}
diff --git a/test/relu_test.cc b/test/relu_test.cc
index 183f415..fda7a2a 100644
--- a/test/relu_test.cc
+++ b/test/relu_test.cc
@@ -1,4 +1,5 @@
#include "3rd_party/catch.hpp"
+#include "aligned.h"
#include "postprocess.h"
#include <numeric>
@@ -9,47 +10,45 @@ INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
if (kCPU < CPUType::SSE2)
return;
- float raw_input[8];
- std::iota(raw_input, raw_input + 8, -2);
-
- RegisterPair128 input;
- input.pack0123 = *reinterpret_cast<__m128*>(raw_input);
- input.pack4567 = *reinterpret_cast<__m128*>(raw_input + 4);
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
+ std::iota(input.begin(), input.end(), -2);
auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU());
- auto output = postproc.run(input, 0);
- auto raw_output = reinterpret_cast<float*>(&output);
-
- CHECK(raw_output[0] == 0.f); // input = -2
- CHECK(raw_output[1] == 0.f); // input = -1
- CHECK(raw_output[2] == 0.f); // input = 0
- CHECK(raw_output[3] == 1.f); // input = 1
- CHECK(raw_output[4] == 2.f); // input = 2
- CHECK(raw_output[5] == 3.f); // input = 3
- CHECK(raw_output[6] == 4.f); // input = 4
- CHECK(raw_output[7] == 5.f); // input = 5
+ auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0);
+ output.as<__m128>()[0] = output_tmp.pack0123;
+ output.as<__m128>()[1] = output_tmp.pack4567;
+
+ CHECK(output[0] == 0.f); // input = -2
+ CHECK(output[1] == 0.f); // input = -1
+ CHECK(output[2] == 0.f); // input = 0
+ CHECK(output[3] == 1.f); // input = 1
+ CHECK(output[4] == 2.f); // input = 2
+ CHECK(output[5] == 3.f); // input = 3
+ CHECK(output[6] == 4.f); // input = 4
+ CHECK(output[7] == 5.f); // input = 5
}
INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) {
if (kCPU < CPUType::AVX2)
return;
- float raw_input[8];
- std::iota(raw_input, raw_input + 8, -4);
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -4);
- auto input = *reinterpret_cast<__m256*>(raw_input);
auto postproc = PostprocessImpl<ReLU, CPUType::AVX2>(ReLU());
- auto output = postproc.run(input, 0);
- auto raw_output = reinterpret_cast<float*>(&output);
-
- CHECK(raw_output[0] == 0.f); // input = -4
- CHECK(raw_output[1] == 0.f); // input = -3
- CHECK(raw_output[2] == 0.f); // input = -2
- CHECK(raw_output[3] == 0.f); // input = -1
- CHECK(raw_output[4] == 0.f); // input = 0
- CHECK(raw_output[5] == 1.f); // input = 1
- CHECK(raw_output[6] == 2.f); // input = 2
- CHECK(raw_output[7] == 3.f); // input = 3
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK(output[0] == 0.f); // input = -4
+ CHECK(output[1] == 0.f); // input = -3
+ CHECK(output[2] == 0.f); // input = -2
+ CHECK(output[3] == 0.f); // input = -1
+ CHECK(output[4] == 0.f); // input = 0
+ CHECK(output[5] == 1.f); // input = 1
+ CHECK(output[6] == 2.f); // input = 2
+ CHECK(output[7] == 3.f); // input = 3
}
#ifndef INTGEMM_NO_AVX512
@@ -58,30 +57,30 @@ INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) {
if (kCPU < CPUType::AVX512BW)
return;
- float raw_input[16];
- std::iota(raw_input, raw_input + 16, -8);
+ AlignedVector<float> input(16);
+ AlignedVector<float> output(16);
+
+ std::iota(input.begin(), input.end(), -8);
- auto input = *reinterpret_cast<__m512*>(raw_input);
auto postproc = PostprocessImpl<ReLU, CPUType::AVX512BW>(ReLU());
- auto output = postproc.run(input, 0);
- auto raw_output = reinterpret_cast<float*>(&output);
-
- CHECK(raw_output[0] == 0.f); // input = -8
- CHECK(raw_output[1] == 0.f); // input = -7
- CHECK(raw_output[2] == 0.f); // input = -6
- CHECK(raw_output[3] == 0.f); // input = -5
- CHECK(raw_output[4] == 0.f); // input = -4
- CHECK(raw_output[5] == 0.f); // input = -3
- CHECK(raw_output[6] == 0.f); // input = -2
- CHECK(raw_output[7] == 0.f); // input = -1
- CHECK(raw_output[8] == 0.f); // input = 0
- CHECK(raw_output[9] == 1.f); // input = 1
- CHECK(raw_output[10] == 2.f); // input = 2
- CHECK(raw_output[11] == 3.f); // input = 3
- CHECK(raw_output[12] == 4.f); // input = 4
- CHECK(raw_output[13] == 5.f); // input = 5
- CHECK(raw_output[14] == 6.f); // input = 6
- CHECK(raw_output[15] == 7.f); // input = 7
+ *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0);
+
+ CHECK(output[0] == 0.f); // input = -8
+ CHECK(output[1] == 0.f); // input = -7
+ CHECK(output[2] == 0.f); // input = -6
+ CHECK(output[3] == 0.f); // input = -5
+ CHECK(output[4] == 0.f); // input = -4
+ CHECK(output[5] == 0.f); // input = -3
+ CHECK(output[6] == 0.f); // input = -2
+ CHECK(output[7] == 0.f); // input = -1
+ CHECK(output[8] == 0.f); // input = 0
+ CHECK(output[9] == 1.f); // input = 1
+ CHECK(output[10] == 2.f); // input = 2
+ CHECK(output[11] == 3.f); // input = 3
+ CHECK(output[12] == 4.f); // input = 4
+ CHECK(output[13] == 5.f); // input = 5
+ CHECK(output[14] == 6.f); // input = 6
+ CHECK(output[15] == 7.f); // input = 7
}
#endif
diff --git a/test/sigmoid_test.cc b/test/sigmoid_test.cc
index 86f85d4..fc50e37 100644
--- a/test/sigmoid_test.cc
+++ b/test/sigmoid_test.cc
@@ -1,4 +1,5 @@
#include "3rd_party/catch.hpp"
+#include "aligned.h"
#include "postprocess.h"
#include <numeric>
@@ -17,22 +18,22 @@ INTGEMM_AVX2 TEST_CASE("Sigmoid AVX2",) {
const float error_tolerance = 0.001f;
- __m256 input;
- auto raw = reinterpret_cast<float*>(&input);
- std::iota(raw, raw + 8, -4);
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -4);
auto postproc = PostprocessImpl<Sigmoid, CPUType::AVX2>(Sigmoid());
- auto output = postproc.run(input, 0);
- auto raw_output = reinterpret_cast<float*>(&output);
-
- CHECK_FLOAT(raw_output[0], 0.0179862f, error_tolerance); // input = -4
- CHECK_FLOAT(raw_output[1], 0.0474259f, error_tolerance); // input = -3
- CHECK_FLOAT(raw_output[2], 0.1192029f, error_tolerance); // input = -2
- CHECK_FLOAT(raw_output[3], 0.2689414f, error_tolerance); // input = -1
- CHECK_FLOAT(raw_output[4], 0.5f , error_tolerance); // input = 0
- CHECK_FLOAT(raw_output[5], 0.7310586f, error_tolerance); // input = 1
- CHECK_FLOAT(raw_output[6], 0.8807970f, error_tolerance); // input = 2
- CHECK_FLOAT(raw_output[7], 0.9525740f, error_tolerance); // input = 3
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK_FLOAT(output[0], 0.0179862f, error_tolerance); // input = -4
+ CHECK_FLOAT(output[1], 0.0474259f, error_tolerance); // input = -3
+ CHECK_FLOAT(output[2], 0.1192029f, error_tolerance); // input = -2
+ CHECK_FLOAT(output[3], 0.2689414f, error_tolerance); // input = -1
+ CHECK_FLOAT(output[4], 0.5f , error_tolerance); // input = 0
+ CHECK_FLOAT(output[5], 0.7310586f, error_tolerance); // input = 1
+ CHECK_FLOAT(output[6], 0.8807970f, error_tolerance); // input = 2
+ CHECK_FLOAT(output[7], 0.9525740f, error_tolerance); // input = 3
}
}
diff --git a/test/tanh_test.cc b/test/tanh_test.cc
index 72e2555..54c34fd 100644
--- a/test/tanh_test.cc
+++ b/test/tanh_test.cc
@@ -1,4 +1,5 @@
#include "3rd_party/catch.hpp"
+#include "aligned.h"
#include "postprocess.h"
#include <numeric>
@@ -17,26 +18,22 @@ INTGEMM_AVX2 TEST_CASE("Tanh AVX2",) {
const float error_tolerance = 0.001f;
- __m256 input;
+ AlignedVector<float> input(8);
+ AlignedVector<float> output(8);
- { // fill
- auto raw = reinterpret_cast<float*>(&input);
- int n = -4;
- std::generate(raw, raw + 8, [&n] () { return n++ / 4.f; });
- }
+ std::generate(input.begin(), input.end(), [] () { static int n = -4; return n++ / 4.f; });
auto postproc = PostprocessImpl<Tanh, CPUType::AVX2>(Tanh());
- auto output = postproc.run(input, 0);
- auto raw_output = reinterpret_cast<float*>(&output);
-
- CHECK_FLOAT(raw_output[0], -0.7615942f, error_tolerance); // input = -1
- CHECK_FLOAT(raw_output[1], -0.6351490f, error_tolerance); // input = -0.75
- CHECK_FLOAT(raw_output[2], -0.4621172f, error_tolerance); // input = -0.5
- CHECK_FLOAT(raw_output[3], -0.2449187f, error_tolerance); // input = -0.25
- CHECK_FLOAT(raw_output[4], 0.0f , error_tolerance); // input = 0
- CHECK_FLOAT(raw_output[5], 0.2449187f, error_tolerance); // input = 0.25
- CHECK_FLOAT(raw_output[6], 0.4621172f, error_tolerance); // input = 0.5
- CHECK_FLOAT(raw_output[7], 0.6351490f, error_tolerance); // input = 0.75
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK_FLOAT(output[0], -0.7615942f, error_tolerance); // input = -1
+ CHECK_FLOAT(output[1], -0.6351490f, error_tolerance); // input = -0.75
+ CHECK_FLOAT(output[2], -0.4621172f, error_tolerance); // input = -0.5
+ CHECK_FLOAT(output[3], -0.2449187f, error_tolerance); // input = -0.25
+ CHECK_FLOAT(output[4], 0.0f , error_tolerance); // input = 0
+ CHECK_FLOAT(output[5], 0.2449187f, error_tolerance); // input = 0.25
+ CHECK_FLOAT(output[6], 0.4621172f, error_tolerance); // input = 0.5
+ CHECK_FLOAT(output[7], 0.6351490f, error_tolerance); // input = 0.75
}
}