Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-01-25 05:42:44 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-01-25 05:42:44 +0300
commit5adaf309658265be5c77cebcef4334769dd10903 (patch)
tree68f4c95670119bb60a278b0f09a25412e87935b1 /src/graph/node_operators_unary.h
parent622260e2006c9ba67d4f0532954a428278ad2e4b (diff)
refactored layers
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h86
1 files changed, 65 insertions, 21 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 18687c0a..4a1945ef 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -34,10 +34,9 @@ struct LogitNodeOp : public UnaryNodeOp {
NodeOps backwardOps() {
return {
- NodeOp(Element(_1 += _2 * _3 * (1.0f - _3),
- children_[0]->grad(),
- adj_,
- val_))
+ NodeOp(Add(_1 * _2 * (1.0f - _2),
+ children_[0]->grad(),
+ adj_, val_))
};
}
@@ -46,26 +45,70 @@ struct LogitNodeOp : public UnaryNodeOp {
}
};
-struct TanhNodeOp : public UnaryNodeOp {
- template <typename ...Args>
- TanhNodeOp(Args ...args)
- : UnaryNodeOp(args...) { }
+struct TanhNodeOp : public NaryNodeOp {
+ TanhNodeOp(const std::vector<Expr>& nodes)
+ : NaryNodeOp(nodes, keywords::shape=newShape(nodes)) { }
+
+ Shape newShape(const std::vector<Expr>& nodes) {
+ Shape shape = nodes[0]->shape();
+
+ for(int n = 1; n < nodes.size(); ++n) {
+ Shape shapen = nodes[n]->shape();
+ for(int i = 0; i < shapen.size(); ++i) {
+ UTIL_THROW_IF2(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1,
+ "Shapes cannot be broadcasted");
+ shape.set(i, std::max(shape[i], shapen[i]));
+ }
+ }
+ return shape;
+ }
NodeOps forwardOps() {
- return {
- NodeOp(Element(_1 = Tanh(_2),
- val_,
- children_[0]->val()))
- };
+ switch (children_.size()) {
+ case 1:
+ return { NodeOp(Element(_1 = Tanh(_2),
+ val_,
+ children_[0]->val())) };
+ case 2:
+ return { NodeOp(Element(_1 = Tanh(_2 + _3),
+ val_,
+ children_[0]->val(),
+ children_[1]->val())) };
+ case 3:
+ return { NodeOp(Element(_1 = Tanh(_2 + _3 + _4),
+ val_,
+ children_[0]->val(),
+ children_[1]->val(),
+ children_[2]->val())) };
+ default:
+ return {
+ NodeOp(
+ Element(_1 = _2 + _3 + _4,
+ val_,
+ children_[0]->val(),
+ children_[1]->val(),
+ children_[2]->val());
+ for(int i = 3; i < children_.size(); ++i)
+ Element(_1 += _2, val_, children_[i]->val());
+ Element(_1 = Tanh(_1), val_);
+ )
+ };
+ }
}
NodeOps backwardOps() {
- return {
- NodeOp(Element(_1 += _2 * (1.0f - (_3 * _3)),
- children_[0]->grad(),
- adj_,
- val_))
- };
+ NodeOps ops;
+ for(auto&& child : children_) {
+ ops.push_back(
+ NodeOp(Add(_1 * (1.0f - (_2 * _2)),
+ child->grad(), adj_, val_))
+ );
+ }
+ return ops;
+ }
+
+ const std::string color() {
+ return "yellow";
}
const std::string type() {
@@ -103,8 +146,9 @@ struct ReLUNodeOp : public UnaryNodeOp {
NodeOps backwardOps() {
return {
- NodeOp(Element(_1 += _2 * ReLUback(_3),
- children_[0]->grad(), adj_, children_[0]->val()))
+ NodeOp(Add(_1 * ReLUback(_2),
+ children_[0]->grad(),
+ adj_, children_[0]->val()))
};
}