diff options
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/expression_operators.cpp | 9 | ||||
-rwxr-xr-x | src/graph/expression_operators.h | 3 | ||||
-rwxr-xr-x | src/graph/node_operators_unary.h | 35 |
3 files changed, 40 insertions, 7 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index c0640edb..e558ffd0 100644 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -51,6 +51,10 @@ Expr swish(Expr a) { return Expression<SwishNodeOp>(a); } +Expr gelu(Expr a) { + return Expression<SwishNodeOp>(a, 1.702f); +} + Expr operator-(Expr a) { return Expression<NegNodeOp>(a); }; @@ -529,6 +533,11 @@ Expr swish(const std::vector<Expr>& nodes) { return swish(nodes[0]); } +Expr gelu(const std::vector<Expr>& nodes) { + ABORT_IF(nodes.size() > 1, "Not implemented"); + return gelu(nodes[0]); +} + Expr tanh(const std::vector<Expr>& nodes) { return Expression<TanhNodeOp>(nodes); } diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index 78aed834..0205fe86 100755 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -17,6 +17,9 @@ Expr sigmoid(const std::vector<Expr>&); Expr swish(Expr a); Expr swish(const std::vector<Expr>&); +Expr gelu(Expr a); +Expr gelu(const std::vector<Expr>&); + Expr tanh(const std::vector<Expr>&); template <typename... Args> diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 63756da3..190fa947 100755 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -343,31 +343,52 @@ private: * in an expression graph. * * This node implements the activation function - * \f$ f(x) = x \cdot \sigma(x) \f$ + * \f$ f(x) = x \cdot \sigma(bx) \f$ * and its derivative - * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ . + * \f$ f^\prime(x) = bf(x) + \sigma(bx)(1 - bf(x)) \f$ . * */ struct SwishNodeOp : public UnaryNodeOp { - SwishNodeOp(Expr a) : UnaryNodeOp(a) {} + SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {} NodeOps forwardOps() override { using namespace functional; - return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))}; + return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))}; } NodeOps backwardOps() override { using namespace functional; - // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) ) - return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)), + // dJ/dx += dJ/df * (b*f(x) + sigmoid(b*x) * (1 - b*f(x))) + return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))), child(0)->grad(), // dJ/dx adj_, // _1 := dJ/df child(0)->val(), // _2 := x - val_ // _3 := f(x) = x*sigma(x) + val_ // _3 := f(x) = x*sigmoid(b*x) ))}; } const std::string type() override { return "swish"; } + + virtual size_t hash() override { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + util::hash_combine(hash_, b_); + } + return hash_; + } + + virtual bool equal(Expr node) override { + if(!NaryNodeOp::equal(node)) + return false; + Ptr<SwishNodeOp> cnode = std::dynamic_pointer_cast<SwishNodeOp>(node); + if(!cnode) + return false; + if(b_ != cnode->b_) + return false; + return true; + } + + float b_; }; struct SoftmaxNodeOp : public UnaryNodeOp { |