diff options
author | Frank Seide <fseide@microsoft.com> | 2019-02-07 19:35:21 +0300 |
---|---|---|
committer | Frank Seide <fseide@microsoft.com> | 2019-02-07 19:35:21 +0300 |
commit | 5f498c9c658c65f89b85df2ef4ef6eaaa3f2ebf5 (patch) | |
tree | f3b627c7a5f7cb089b8283fd29c308915c5e6047 /src/graph | |
parent | 46e9565c9a06e54f4f05f50f722da726ae437df0 (diff) | |
parent | 9b54e7f1caa86b16da84638a20566e8d5451c170 (diff) |
merged from fseide/commentbeamsearch
Diffstat (limited to 'src/graph')
-rw-r--r--[-rwxr-xr-x] | src/graph/chainable.h | 0 | ||||
-rw-r--r--[-rwxr-xr-x] | src/graph/expression_graph.cpp | 0 | ||||
-rw-r--r--[-rwxr-xr-x] | src/graph/expression_graph.h | 0 | ||||
-rwxr-xr-x | src/graph/expression_operators.cpp | 11 | ||||
-rwxr-xr-x | src/graph/expression_operators.h | 3 | ||||
-rw-r--r--[-rwxr-xr-x] | src/graph/node.cpp | 0 | ||||
-rw-r--r--[-rwxr-xr-x] | src/graph/node.h | 0 | ||||
-rw-r--r--[-rwxr-xr-x] | src/graph/node_initializers.cpp | 0 | ||||
-rw-r--r--[-rwxr-xr-x] | src/graph/node_initializers.h | 0 | ||||
-rw-r--r--[-rwxr-xr-x] | src/graph/node_operators_binary.h | 0 | ||||
-rwxr-xr-x | src/graph/node_operators_unary.h | 35 | ||||
-rw-r--r--[-rwxr-xr-x] | src/graph/parameters.h | 0 |
12 files changed, 41 insertions, 8 deletions
diff --git a/src/graph/chainable.h b/src/graph/chainable.h index 2679843e..2679843e 100755..100644 --- a/src/graph/chainable.h +++ b/src/graph/chainable.h diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp index b3c237d1..b3c237d1 100755..100644 --- a/src/graph/expression_graph.cpp +++ b/src/graph/expression_graph.cpp diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h index dfd00a9a..dfd00a9a 100755..100644 --- a/src/graph/expression_graph.h +++ b/src/graph/expression_graph.h diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 866c4b2c..5c37beef 100755 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -51,6 +51,10 @@ Expr swish(Expr a) { return Expression<SwishNodeOp>(a); } +Expr gelu(Expr a) { + return Expression<SwishNodeOp>(a, 1.702f); +} + Expr operator-(Expr a) { return Expression<NegNodeOp>(a); }; @@ -532,7 +536,7 @@ Expr swapAxes(Expr x, int axis1, int axis2) return x; // TODO: This is code dup from transpose(x). Implement transpose(x) as swapAxes(x, 0, 1) std::vector<int> axes(x->shape().size()); - for (int i = 0; i < axes.size(); ++i) + for (int i = 0; i < axes.size(); ++i) // @TODO: use std::iota() axes[i] = i; std::swap(axes[axis1], axes[axis2]); return transpose(x, axes); @@ -552,6 +556,11 @@ Expr swish(const std::vector<Expr>& nodes) { return swish(nodes[0]); } +Expr gelu(const std::vector<Expr>& nodes) { + ABORT_IF(nodes.size() > 1, "Not implemented"); + return gelu(nodes[0]); +} + Expr tanh(const std::vector<Expr>& nodes) { return Expression<TanhNodeOp>(nodes); } diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index 78aed834..0205fe86 100755 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -17,6 +17,9 @@ Expr sigmoid(const std::vector<Expr>&); Expr swish(Expr a); Expr swish(const std::vector<Expr>&); +Expr gelu(Expr a); +Expr gelu(const std::vector<Expr>&); + Expr tanh(const std::vector<Expr>&); template <typename... Args> diff --git a/src/graph/node.cpp b/src/graph/node.cpp index c11531da..c11531da 100755..100644 --- a/src/graph/node.cpp +++ b/src/graph/node.cpp diff --git a/src/graph/node.h b/src/graph/node.h index 1397e74b..1397e74b 100755..100644 --- a/src/graph/node.h +++ b/src/graph/node.h diff --git a/src/graph/node_initializers.cpp b/src/graph/node_initializers.cpp index b2550c54..b2550c54 100755..100644 --- a/src/graph/node_initializers.cpp +++ b/src/graph/node_initializers.cpp diff --git a/src/graph/node_initializers.h b/src/graph/node_initializers.h index fbb07348..fbb07348 100755..100644 --- a/src/graph/node_initializers.h +++ b/src/graph/node_initializers.h diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index 07245391..07245391 100755..100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 6dd90faf..39fb366f 100755 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -343,31 +343,52 @@ private: * in an expression graph. * * This node implements the activation function - * \f$ f(x) = x \cdot \sigma(x) \f$ + * \f$ f(x) = x \cdot \sigma(bx) \f$ * and its derivative - * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ . + * \f$ f^\prime(x) = bf(x) + \sigma(bx)(1 - bf(x)) \f$ . * */ struct SwishNodeOp : public UnaryNodeOp { - SwishNodeOp(Expr a) : UnaryNodeOp(a) {} + SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {} NodeOps forwardOps() override { using namespace functional; - return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))}; + return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))}; } NodeOps backwardOps() override { using namespace functional; - // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) ) - return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)), + // dJ/dx += dJ/df * (b*f(x) + sigmoid(b*x) * (1 - b*f(x))) + return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))), child(0)->grad(), // dJ/dx adj_, // _1 := dJ/df child(0)->val(), // _2 := x - val_ // _3 := f(x) = x*sigma(x) + val_ // _3 := f(x) = x*sigmoid(b*x) ))}; } const std::string type() override { return "swish"; } + + virtual size_t hash() override { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + util::hash_combine(hash_, b_); + } + return hash_; + } + + virtual bool equal(Expr node) override { + if(!NaryNodeOp::equal(node)) + return false; + Ptr<SwishNodeOp> cnode = std::dynamic_pointer_cast<SwishNodeOp>(node); + if(!cnode) + return false; + if(b_ != cnode->b_) + return false; + return true; + } + + float b_; }; struct SoftmaxNodeOp : public UnaryNodeOp { diff --git a/src/graph/parameters.h b/src/graph/parameters.h index 32f88a1e..32f88a1e 100755..100644 --- a/src/graph/parameters.h +++ b/src/graph/parameters.h |