Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-10-22 22:18:30 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-10-22 22:18:30 +0300
commita7891fe98a9cc23c3861accb13f4da5e16f8bebb (patch)
treea058ce409e1fc4408eed25b3b2074bd8f8653571 /src/graph/node_operators_unary.h
parentf0d728d752d7c15aa6d277fdb658db6b78993fb0 (diff)
swish
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h16
1 files changed, 16 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index ddac6378..3eb1c7ae 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -228,6 +228,22 @@ struct ReLUNodeOp : public UnaryNodeOp {
const std::string type() { return "ReLU"; }
};
+struct SwishNodeOp : public UnaryNodeOp {
+ template <typename... Args>
+ SwishNodeOp(Args... args) : UnaryNodeOp(args...) {}
+
+ NodeOps forwardOps() {
+ return {NodeOp(Element(_1 = _2 * Sigma(_2), val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(
+ Add(_1 * (_3 + Sigma(_2) * (1.f - _3)), child(0)->grad(), adj_, child(0)->val(), val_))};
+ }
+
+ const std::string type() { return "swish"; }
+};
+
struct SoftmaxNodeOp : public NaryNodeOp {
template <typename... Args>
SoftmaxNodeOp(Expr a, Args... args)