diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-06-21 19:55:52 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-06-25 15:34:05 +0300 |
commit | 968713e30ef28b9422b445733080c33641722e33 (patch) | |
tree | 257fd271a2ca303ca89f739a0b8cd70568c40a3d | |
parent | c8e0a7db25c521589c1f1894f67789160e82bef8 (diff) |
Add unit test for AddBias postprocess
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | test/add_bias_test.cc | 95 |
2 files changed, 96 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2061bf2..070a8b7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -32,6 +32,7 @@ endforeach()
 
 include_directories(.)
 add_executable(tests
+  test/add_bias_test.cc
   test/multiply_test.cc
   test/pipeline_test.cc
   test/quantize_test.cc
diff --git a/test/add_bias_test.cc b/test/add_bias_test.cc
new file mode 100644
index 0000000..8d8b46c
--- /dev/null
+++ b/test/add_bias_test.cc
@@ -0,0 +1,95 @@
+#include "3rd_party/catch.hpp"
+#include "aligned.h"
+#include "postprocess.h"
+
+#include <numeric>
+
+namespace intgemm {
+
+INTGEMM_SSE2 TEST_CASE("AddBias SSE2",) { // Expects output[i] == input[i] + bias[i] for 8 floats handled as two __m128 halves.
+  if (kCPU < CPUType::SSE2) // Skip the checks when kCPU is below the SSE2 level.
+    return;
+
+  AlignedVector<float> input(8);
+  AlignedVector<float> bias(8);
+  AlignedVector<float> output(8);
+
+  std::iota(input.begin(), input.end(), -2); // input = -2, -1, ..., 5
+  std::iota(bias.begin(), bias.end(), 0); // bias = 0, 1, ..., 7
+
+  auto postproc = PostprocessImpl<AddBias, CPUType::SSE2>(AddBias(bias.begin(), bias.size()));
+  auto output_tmp = postproc.run({input.as<__m128>()[0], input.as<__m128>()[1]}, 0); // NOTE(review): trailing 0 is presumably an offset into bias -- confirm in postprocess.h.
+  output.as<__m128>()[0] = output_tmp.pack0123; // Store the two result halves back so they can be read as floats.
+  output.as<__m128>()[1] = output_tmp.pack4567;
+
+  CHECK(output[0] == -2.f); // input = -2, bias = 0
+  CHECK(output[1] == 0.f); // input = -1, bias = 1
+  CHECK(output[2] == 2.f); // input = 0, bias = 2
+  CHECK(output[3] == 4.f); // input = 1, bias = 3
+  CHECK(output[4] == 6.f); // input = 2, bias = 4
+  CHECK(output[5] == 8.f); // input = 3, bias = 5
+  CHECK(output[6] == 10.f); // input = 4, bias = 6
+  CHECK(output[7] == 12.f); // input = 5, bias = 7
+}
+
+INTGEMM_AVX2 TEST_CASE("AddBias AVX2",) { // Expects output[i] == input[i] + bias[i] for 8 floats handled as one __m256.
+  if (kCPU < CPUType::AVX2) // Skip the checks when kCPU is below the AVX2 level.
+    return;
+
+  AlignedVector<float> input(8);
+  AlignedVector<float> bias(8);
+  AlignedVector<float> output(8);
+
+  std::iota(input.begin(), input.end(), -4); // input = -4, -3, ..., 3
+  std::iota(bias.begin(), bias.end(), 0); // bias = 0, 1, ..., 7
+
+  auto postproc = PostprocessImpl<AddBias, CPUType::AVX2>(AddBias(bias.begin(), bias.size()));
+  *output.as<__m256>() = postproc.run(*input.as<__m256>(), 0); // NOTE(review): trailing 0 is presumably an offset into bias -- confirm in postprocess.h.
+
+  CHECK(output[0] == -4.f); // input = -4, bias = 0
+  CHECK(output[1] == -2.f); // input = -3, bias = 1
+  CHECK(output[2] == 0.f); // input = -2, bias = 2
+  CHECK(output[3] == 2.f); // input = -1, bias = 3
+  CHECK(output[4] == 4.f); // input = 0, bias = 4
+  CHECK(output[5] == 6.f); // input = 1, bias = 5
+  CHECK(output[6] == 8.f); // input = 2, bias = 6
+  CHECK(output[7] == 10.f); // input = 3, bias = 7
+}
+
+#ifndef INTGEMM_NO_AVX512
+
+INTGEMM_AVX512BW TEST_CASE("AddBias AVX512",) { // Expects output[i] == input[i] + bias[i] for 16 floats handled as one __m512.
+  if (kCPU < CPUType::AVX512BW) // Skip the checks when kCPU is below the AVX512BW level.
+    return;
+
+  AlignedVector<float> input(16);
+  AlignedVector<float> bias(16);
+  AlignedVector<float> output(16);
+
+  std::iota(input.begin(), input.end(), -8); // input = -8, -7, ..., 7
+  std::iota(bias.begin(), bias.end(), 0); // bias = 0, 1, ..., 15
+
+  auto postproc = PostprocessImpl<AddBias, CPUType::AVX512BW>(AddBias(bias.begin(), bias.size()));
+  *output.as<__m512>() = postproc.run(*input.as<__m512>(), 0); // NOTE(review): trailing 0 is presumably an offset into bias -- confirm in postprocess.h.
+
+  CHECK(output[0] == -8.f); // input = -8, bias = 0
+  CHECK(output[1] == -6.f); // input = -7, bias = 1
+  CHECK(output[2] == -4.f); // input = -6, bias = 2
+  CHECK(output[3] == -2.f); // input = -5, bias = 3
+  CHECK(output[4] == 0.f); // input = -4, bias = 4
+  CHECK(output[5] == 2.f); // input = -3, bias = 5
+  CHECK(output[6] == 4.f); // input = -2, bias = 6
+  CHECK(output[7] == 6.f); // input = -1, bias = 7
+  CHECK(output[8] == 8.f); // input = 0, bias = 8
+  CHECK(output[9] == 10.f); // input = 1, bias = 9
+  CHECK(output[10] == 12.f); // input = 2, bias = 10
+  CHECK(output[11] == 14.f); // input = 3, bias = 11
+  CHECK(output[12] == 16.f); // input = 4, bias = 12
+  CHECK(output[13] == 18.f); // input = 5, bias = 13
+  CHECK(output[14] == 20.f); // input = 6, bias = 14
+  CHECK(output[15] == 22.f); // input = 7, bias = 15
+}
+
+#endif
+
+}