diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-22 22:18:30 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-22 22:18:30 +0300 |
commit | a7891fe98a9cc23c3861accb13f4da5e16f8bebb (patch) | |
tree | a058ce409e1fc4408eed25b3b2074bd8f8653571 /src/graph/node_operators_unary.h | |
parent | f0d728d752d7c15aa6d277fdb658db6b78993fb0 (diff) |
swish
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index ddac6378..3eb1c7ae 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -228,6 +228,22 @@ struct ReLUNodeOp : public UnaryNodeOp { const std::string type() { return "ReLU"; } }; +struct SwishNodeOp : public UnaryNodeOp { + template <typename... Args> + SwishNodeOp(Args... args) : UnaryNodeOp(args...) {} + + NodeOps forwardOps() { + return {NodeOp(Element(_1 = _2 * Sigma(_2), val_, child(0)->val()))}; + } + + NodeOps backwardOps() { + return {NodeOp( + Add(_1 * (_3 + Sigma(_2) * (1.f - _3)), child(0)->grad(), adj_, child(0)->val(), val_))}; + } + + const std::string type() { return "swish"; } +}; + struct SoftmaxNodeOp : public NaryNodeOp { template <typename... Args> SoftmaxNodeOp(Expr a, Args... args) |