Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src/graph
diff options
context:
space:
mode:
author: Frank Seide <fseide@microsoft.com> 2019-02-07 19:35:21 +0300
committer: Frank Seide <fseide@microsoft.com> 2019-02-07 19:35:21 +0300
commit: 5f498c9c658c65f89b85df2ef4ef6eaaa3f2ebf5 (patch)
tree: f3b627c7a5f7cb089b8283fd29c308915c5e6047 /src/graph
parent: 46e9565c9a06e54f4f05f50f722da726ae437df0 (diff)
parent: 9b54e7f1caa86b16da84638a20566e8d5451c170 (diff)
merged from fseide/commentbeamsearch
Diffstat (limited to 'src/graph')
-rw-r--r--[-rwxr-xr-x] src/graph/chainable.h 0
-rw-r--r--[-rwxr-xr-x] src/graph/expression_graph.cpp 0
-rw-r--r--[-rwxr-xr-x] src/graph/expression_graph.h 0
-rwxr-xr-x src/graph/expression_operators.cpp 11
-rwxr-xr-x src/graph/expression_operators.h 3
-rw-r--r--[-rwxr-xr-x] src/graph/node.cpp 0
-rw-r--r--[-rwxr-xr-x] src/graph/node.h 0
-rw-r--r--[-rwxr-xr-x] src/graph/node_initializers.cpp 0
-rw-r--r--[-rwxr-xr-x] src/graph/node_initializers.h 0
-rw-r--r--[-rwxr-xr-x] src/graph/node_operators_binary.h 0
-rwxr-xr-x src/graph/node_operators_unary.h 35
-rw-r--r--[-rwxr-xr-x] src/graph/parameters.h 0
12 files changed, 41 insertions, 8 deletions
diff --git a/src/graph/chainable.h b/src/graph/chainable.h
index 2679843e..2679843e 100755..100644
--- a/src/graph/chainable.h
+++ b/src/graph/chainable.h
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp
index b3c237d1..b3c237d1 100755..100644
--- a/src/graph/expression_graph.cpp
+++ b/src/graph/expression_graph.cpp
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index dfd00a9a..dfd00a9a 100755..100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 866c4b2c..5c37beef 100755
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -51,6 +51,10 @@ Expr swish(Expr a) {
return Expression<SwishNodeOp>(a);
}
+Expr gelu(Expr a) {
+ return Expression<SwishNodeOp>(a, 1.702f);
+}
+
Expr operator-(Expr a) {
return Expression<NegNodeOp>(a);
};
@@ -532,7 +536,7 @@ Expr swapAxes(Expr x, int axis1, int axis2)
return x;
// TODO: This is code dup from transpose(x). Implement transpose(x) as swapAxes(x, 0, 1)
std::vector<int> axes(x->shape().size());
- for (int i = 0; i < axes.size(); ++i)
+ for (int i = 0; i < axes.size(); ++i) // @TODO: use std::iota()
axes[i] = i;
std::swap(axes[axis1], axes[axis2]);
return transpose(x, axes);
@@ -552,6 +556,11 @@ Expr swish(const std::vector<Expr>& nodes) {
return swish(nodes[0]);
}
+Expr gelu(const std::vector<Expr>& nodes) {
+ ABORT_IF(nodes.size() > 1, "Not implemented");
+ return gelu(nodes[0]);
+}
+
Expr tanh(const std::vector<Expr>& nodes) {
return Expression<TanhNodeOp>(nodes);
}
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 78aed834..0205fe86 100755
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -17,6 +17,9 @@ Expr sigmoid(const std::vector<Expr>&);
Expr swish(Expr a);
Expr swish(const std::vector<Expr>&);
+Expr gelu(Expr a);
+Expr gelu(const std::vector<Expr>&);
+
Expr tanh(const std::vector<Expr>&);
template <typename... Args>
diff --git a/src/graph/node.cpp b/src/graph/node.cpp
index c11531da..c11531da 100755..100644
--- a/src/graph/node.cpp
+++ b/src/graph/node.cpp
diff --git a/src/graph/node.h b/src/graph/node.h
index 1397e74b..1397e74b 100755..100644
--- a/src/graph/node.h
+++ b/src/graph/node.h
diff --git a/src/graph/node_initializers.cpp b/src/graph/node_initializers.cpp
index b2550c54..b2550c54 100755..100644
--- a/src/graph/node_initializers.cpp
+++ b/src/graph/node_initializers.cpp
diff --git a/src/graph/node_initializers.h b/src/graph/node_initializers.h
index fbb07348..fbb07348 100755..100644
--- a/src/graph/node_initializers.h
+++ b/src/graph/node_initializers.h
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 07245391..07245391 100755..100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 6dd90faf..39fb366f 100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -343,31 +343,52 @@ private:
* in an expression graph.
*
* This node implements the activation function
- * \f$ f(x) = x \cdot \sigma(x) \f$
+ * \f$ f(x) = x \cdot \sigma(bx) \f$
* and its derivative
- * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ .
+ * \f$ f^\prime(x) = bf(x) + \sigma(bx)(1 - bf(x)) \f$ .
*
*/
struct SwishNodeOp : public UnaryNodeOp {
- SwishNodeOp(Expr a) : UnaryNodeOp(a) {}
+ SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {}
NodeOps forwardOps() override {
using namespace functional;
- return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))};
+ return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
- // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
- return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)),
+ // dJ/dx += dJ/df * (b*f(x) + sigmoid(b*x) * (1 - b*f(x)))
+ return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))),
child(0)->grad(), // dJ/dx
adj_, // _1 := dJ/df
child(0)->val(), // _2 := x
- val_ // _3 := f(x) = x*sigma(x)
+ val_ // _3 := f(x) = x*sigmoid(b*x)
))};
}
const std::string type() override { return "swish"; }
+
+ virtual size_t hash() override {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ util::hash_combine(hash_, b_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ Ptr<SwishNodeOp> cnode = std::dynamic_pointer_cast<SwishNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(b_ != cnode->b_)
+ return false;
+ return true;
+ }
+
+ float b_;
};
struct SoftmaxNodeOp : public UnaryNodeOp {
diff --git a/src/graph/parameters.h b/src/graph/parameters.h
index 32f88a1e..32f88a1e 100755..100644
--- a/src/graph/parameters.h
+++ b/src/graph/parameters.h