From 79368b121f29a0c12beada95473655220ead8827 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Sat, 26 Jan 2019 14:14:57 -0800 Subject: add gelu activation --- src/graph/node_operators_unary.h | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) (limited to 'src/graph/node_operators_unary.h') diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 63756da3..190fa947 100755 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -343,31 +343,52 @@ private: * in an expression graph. * * This node implements the activation function - * \f$ f(x) = x \cdot \sigma(x) \f$ + * \f$ f(x) = x \cdot \sigma(bx) \f$ * and its derivative - * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ . + * \f$ f^\prime(x) = bf(x) + \sigma(bx)(1 - bf(x)) \f$ . * */ struct SwishNodeOp : public UnaryNodeOp { - SwishNodeOp(Expr a) : UnaryNodeOp(a) {} + SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {} NodeOps forwardOps() override { using namespace functional; - return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))}; + return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))}; } NodeOps backwardOps() override { using namespace functional; - // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) ) - return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)), + // dJ/dx += dJ/df * (b*f(x) + sigmoid(b*x) * (1 - b*f(x))) + return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))), child(0)->grad(), // dJ/dx adj_, // _1 := dJ/df child(0)->val(), // _2 := x - val_ // _3 := f(x) = x*sigma(x) + val_ // _3 := f(x) = x*sigmoid(b*x) ))}; } const std::string type() override { return "swish"; } + + virtual size_t hash() override { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + util::hash_combine(hash_, b_); + } + return hash_; + } + + virtual bool equal(Expr node) override { + if(!NaryNodeOp::equal(node)) + return false; + Ptr cnode = std::dynamic_pointer_cast(node); + if(!cnode) + return false; + if(b_ != cnode->b_) + return false; + return true; + } + + float b_; }; struct SoftmaxNodeOp : public UnaryNodeOp { -- cgit v1.2.3