Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-05-29 17:08:11 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-06-18 16:38:54 +0300
commitcc02a23473f888b06b59ecb7af3fa89d064b573f (patch)
treec12d3ac70111a6fb6f4dc7932dbc3c6b2a7a045c /test
parentf0785bea3b42a8e5ab7e322b5ad0dc1e9018d65f (diff)
Change postprocess API
From now, run function takes input vector and offset in dst buffer
Diffstat (limited to 'test')
-rw-r--r--test/multiply_test.cc2
-rw-r--r--test/pipeline_test.cc2
-rw-r--r--test/relu_test.cc6
3 files changed, 5 insertions, 5 deletions
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 0c0becf..00256a8 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -416,7 +416,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
AlignedVector<float> test_C(A_rows * B_cols);
- Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult), AddBias(bias.begin())), A_rows, width, B_cols);
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), test_C.begin(), CreatePostprocessPipeline(Unquantize(unquant_mult), AddBias(bias.begin(), B_cols)), A_rows, width, B_cols);
AlignedVector<Integer> B_quant(B.size());
Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
diff --git a/test/pipeline_test.cc b/test/pipeline_test.cc
index dc3d71e..8544e83 100644
--- a/test/pipeline_test.cc
+++ b/test/pipeline_test.cc
@@ -15,7 +15,7 @@ INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") {
auto input = *reinterpret_cast<__m256i*>(raw_input);
auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU());
auto inited_pipeline = InitPostprocessPipeline<CPU_AVX2>(pipeline);
- auto output = inited_pipeline.run(input);
+ auto output = inited_pipeline.run(input, 0);
float* raw_output = reinterpret_cast<float*>(&output);
diff --git a/test/relu_test.cc b/test/relu_test.cc
index 0c677c9..0a72a29 100644
--- a/test/relu_test.cc
+++ b/test/relu_test.cc
@@ -17,7 +17,7 @@ INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) {
input.pack4567 = *reinterpret_cast<__m128*>(raw_input + 4);
auto postproc = PostprocessImpl<ReLU, CPUType::CPU_SSE2>(ReLU());
- auto output = postproc.run(input);
+ auto output = postproc.run(input, 0);
auto raw_output = reinterpret_cast<float*>(&output);
CHECK(raw_output[0] == 0.f); // input = -2
@@ -39,7 +39,7 @@ INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) {
auto input = *reinterpret_cast<__m256*>(raw_input);
auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX2>(ReLU());
- auto output = postproc.run(input);
+ auto output = postproc.run(input, 0);
auto raw_output = reinterpret_cast<float*>(&output);
CHECK(raw_output[0] == 0.f); // input = -4
@@ -63,7 +63,7 @@ INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) {
auto input = *reinterpret_cast<__m512*>(raw_input);
auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX512BW>(ReLU());
- auto output = postproc.run(input);
+ auto output = postproc.run(input, 0);
auto raw_output = reinterpret_cast<float*>(&output);
CHECK(raw_output[0] == 0.f); // input = -8