Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensors/cpu/prod.cpp')
-rwxr-xr-xsrc/tensors/cpu/prod.cpp17
1 files changed, 17 insertions, 0 deletions
diff --git a/src/tensors/cpu/prod.cpp b/src/tensors/cpu/prod.cpp
index f77337d6..6e28bdd2 100755
--- a/src/tensors/cpu/prod.cpp
+++ b/src/tensors/cpu/prod.cpp
@@ -212,6 +212,23 @@ void ProdWithBias(marian::Tensor C,
cpu::integer::AddBias(C, bias);
}
+void Affine(marian::Tensor C,
+ Ptr<Allocator> /*allocator*/,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
+ const marian::Tensor& bias,
+ bool transA,
+ bool transB,
+ float beta,
+ float scalar,
+ bool reluPostprocess) {
+ using namespace functional;
+ ProdWithBias(C, A, B, bias, transA, transB, beta, scalar);
+ if(reluPostprocess)
+ cpu::Element(_1 = ReLU(_1), C); // @TODO: also fuse with AddBias
+}
+
+
void CSRProd(marian::Tensor C,
Ptr<Allocator> /*allocator*/,
const marian::Tensor& S_values,