Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-06-21 19:55:52 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-06-25 15:34:05 +0300
commit968713e30ef28b9422b445733080c33641722e33 (patch)
tree257fd271a2ca303ca89f739a0b8cd70568c40a3d
parentc8e0a7db25c521589c1f1894f67789160e82bef8 (diff)
Add unit test for AddBias postprocess
-rw-r--r--CMakeLists.txt1
-rw-r--r--test/add_bias_test.cc95
2 files changed, 96 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2061bf2..070a8b7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -32,6 +32,7 @@ endforeach()
include_directories(.)
add_executable(tests
+ test/add_bias_test.cc
test/multiply_test.cc
test/pipeline_test.cc
test/quantize_test.cc
diff --git a/test/add_bias_test.cc b/test/add_bias_test.cc
new file mode 100644
index 0000000..8d8b46c
--- /dev/null
+++ b/test/add_bias_test.cc
@@ -0,0 +1,95 @@
+#include "3rd_party/catch.hpp"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_SSE2 TEST_CASE("AddBias SSE2",) {
+ if (kCPU < CPUType::SSE2)
+ return;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> bias(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -2);
+ std::iota(bias.begin(), bias.end(), 0);
+
+ auto postproc = PostprocessImpl<AddBias, CPUType::SSE2>(AddBias(bias.begin(), bias.size()));
+ auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0);
+ output.as<__m128>()[0] = output_tmp.pack0123;
+ output.as<__m128>()[1] = output_tmp.pack4567;
+
+ CHECK(output[0] == -2.f); // input = -2, bias = 0
+ CHECK(output[1] == 0.f); // input = -1, bias = 1
+ CHECK(output[2] == 2.f); // input = 0, bias = 2
+ CHECK(output[3] == 4.f); // input = 1, bias = 3
+ CHECK(output[4] == 6.f); // input = 2, bias = 4
+ CHECK(output[5] == 8.f); // input = 3, bias = 5
+ CHECK(output[6] == 10.f); // input = 4, bias = 6
+ CHECK(output[7] == 12.f); // input = 5, bias = 7
+}
+
+INTGEMM_AVX2 TEST_CASE("AddBias AVX2",) {
+ if (kCPU < CPUType::AVX2)
+ return;
+
+ AlignedVector<float> input(8);
+ AlignedVector<float> bias(8);
+ AlignedVector<float> output(8);
+
+ std::iota(input.begin(), input.end(), -4);
+ std::iota(bias.begin(), bias.end(), 0);
+
+ auto postproc = PostprocessImpl<AddBias, CPUType::AVX2>(AddBias(bias.begin(), bias.size()));
+ *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0);
+
+ CHECK(output[0] == -4.f); // input = -4, bias = 0
+ CHECK(output[1] == -2.f); // input = -3, bias = 1
+ CHECK(output[2] == 0.f); // input = -2, bias = 2
+ CHECK(output[3] == 2.f); // input = -1, bias = 3
+ CHECK(output[4] == 4.f); // input = 0, bias = 4
+ CHECK(output[5] == 6.f); // input = 1, bias = 5
+ CHECK(output[6] == 8.f); // input = 2, bias = 6
+ CHECK(output[7] == 10.f); // input = 3, bias = 7
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("AddBias AVX512",) {
+ if (kCPU < CPUType::AVX512BW)
+ return;
+
+ AlignedVector<float> input(16);
+ AlignedVector<float> bias(16);
+ AlignedVector<float> output(16);
+
+ std::iota(input.begin(), input.end(), -8);
+ std::iota(bias.begin(), bias.end(), 0);
+
+ auto postproc = PostprocessImpl<AddBias, CPUType::AVX512BW>(AddBias(bias.begin(), bias.size()));
+ *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0);
+
+ CHECK(output[0] == -8.f); // input = -8, bias = 0
+ CHECK(output[1] == -6.f); // input = -7, bias = 1
+ CHECK(output[2] == -4.f); // input = -6, bias = 2
+ CHECK(output[3] == -2.f); // input = -5, bias = 3
+ CHECK(output[4] == 0.f); // input = -4, bias = 4
+ CHECK(output[5] == 2.f); // input = -3, bias = 5
+ CHECK(output[6] == 4.f); // input = -2, bias = 6
+ CHECK(output[7] == 6.f); // input = -1, bias = 7
+ CHECK(output[8] == 8.f); // input = 0, bias = 8
+ CHECK(output[9] == 10.f); // input = 1, bias = 9
+ CHECK(output[10] == 12.f); // input = 2, bias = 10
+ CHECK(output[11] == 14.f); // input = 3, bias = 11
+ CHECK(output[12] == 16.f); // input = 4, bias = 12
+ CHECK(output[13] == 18.f); // input = 5, bias = 13
+ CHECK(output[14] == 20.f); // input = 6, bias = 14
+ CHECK(output[15] == 22.f); // input = 7, bias = 15
+}
+
+#endif
+
+}