Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2019-01-27 01:14:57 +0300
committerMarcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2019-01-27 01:14:57 +0300
commit79368b121f29a0c12beada95473655220ead8827 (patch)
tree17b39372649b08ed2c2943a58562295a346f4d27 /src/graph/node_operators_unary.h
parent50d64de62cb3e47b912fdd3438ffacb566187263 (diff)
add gelu activation
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rwxr-xr-xsrc/graph/node_operators_unary.h35
1 files changed, 28 insertions, 7 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 63756da3..190fa947 100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -343,31 +343,52 @@ private:
* in an expression graph.
*
* This node implements the activation function
- * \f$ f(x) = x \cdot \sigma(x) \f$
+ * \f$ f(x) = x \cdot \sigma(bx) \f$
* and its derivative
- * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ .
+ * \f$ f^\prime(x) = bf(x) + \sigma(bx)(1 - bf(x)) \f$ .
*
*/
struct SwishNodeOp : public UnaryNodeOp {
- SwishNodeOp(Expr a) : UnaryNodeOp(a) {}
+ SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {}
NodeOps forwardOps() override {
using namespace functional;
- return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))};
+ return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
- // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
- return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)),
+ // dJ/dx += dJ/df * (b*f(x) + sigmoid(b*x) * (1 - b*f(x)))
+ return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))),
child(0)->grad(), // dJ/dx
adj_, // _1 := dJ/df
child(0)->val(), // _2 := x
- val_ // _3 := f(x) = x*sigma(x)
+ val_ // _3 := f(x) = x*sigmoid(b*x)
))};
}
const std::string type() override { return "swish"; }
+
+ virtual size_t hash() override {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ util::hash_combine(hash_, b_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ Ptr<SwishNodeOp> cnode = std::dynamic_pointer_cast<SwishNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(b_ != cnode->b_)
+ return false;
+ return true;
+ }
+
+ float b_;
};
struct SoftmaxNodeOp : public UnaryNodeOp {