diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-05-29 17:08:11 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-06-18 16:38:54 +0300 |
commit | cc02a23473f888b06b59ecb7af3fa89d064b573f (patch) | |
tree | c12d3ac70111a6fb6f4dc7932dbc3c6b2a7a045c /test | |
parent | f0785bea3b42a8e5ab7e322b5ad0dc1e9018d65f (diff) |
Change postprocess API
From now, run function takes input vector and offset in dst buffer
Diffstat (limited to 'test')
-rw-r--r-- | test/multiply_test.cc | 2 | ||||
-rw-r--r-- | test/pipeline_test.cc | 2 | ||||
-rw-r--r-- | test/relu_test.cc | 6 |
3 files changed, 5 insertions, 5 deletions
diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 0c0becf..00256a8 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -416,7 +416,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index AlignedVector<float> test_C(A_rows * B_cols); - Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult), AddBias(bias.begin())), A_rows, width, B_cols); + Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult), AddBias(bias.begin(), B_cols)), A_rows, width, B_cols); AlignedVector<Integer> B_quant(B.size()); Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size()); diff --git a/test/pipeline_test.cc b/test/pipeline_test.cc index dc3d71e..8544e83 100644 --- a/test/pipeline_test.cc +++ b/test/pipeline_test.cc @@ -15,7 +15,7 @@ INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") { auto input = *reinterpret_cast<__m256i*>(raw_input); auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU()); auto inited_pipeline = InitPostprocessPipeline<CPU_AVX2>(pipeline); - auto output = inited_pipeline.run(input); + auto output = inited_pipeline.run(input, 0); float* raw_output = reinterpret_cast<float*>(&output); diff --git a/test/relu_test.cc b/test/relu_test.cc index 0c677c9..0a72a29 100644 --- a/test/relu_test.cc +++ b/test/relu_test.cc @@ -17,7 +17,7 @@ INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) { input.pack4567 = *reinterpret_cast<__m128*>(raw_input + 4); auto postproc = PostprocessImpl<ReLU, CPUType::CPU_SSE2>(ReLU()); - auto output = postproc.run(input); + auto output = postproc.run(input, 0); auto raw_output = reinterpret_cast<float*>(&output); CHECK(raw_output[0] == 0.f); // input = -2 @@ -39,7 +39,7 @@ INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) { auto input = *reinterpret_cast<__m256*>(raw_input); auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX2>(ReLU()); - auto output = postproc.run(input); + auto output = postproc.run(input, 0); auto raw_output = reinterpret_cast<float*>(&output); CHECK(raw_output[0] == 0.f); // input = -4 @@ -63,7 +63,7 @@ INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) { auto input = *reinterpret_cast<__m512*>(raw_input); auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX512BW>(ReLU()); - auto output = postproc.run(input); + auto output = postproc.run(input, 0); auto raw_output = reinterpret_cast<float*>(&output); CHECK(raw_output[0] == 0.f); // input = -8 |