diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-01-25 05:42:44 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-01-25 05:42:44 +0300 |
commit | 5adaf309658265be5c77cebcef4334769dd10903 (patch) | |
tree | 68f4c95670119bb60a278b0f09a25412e87935b1 /src/graph/node_operators_unary.h | |
parent | 622260e2006c9ba67d4f0532954a428278ad2e4b (diff) |
refactored layers
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 86 |
1 files changed, 65 insertions, 21 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 18687c0a..4a1945ef 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -34,10 +34,9 @@ struct LogitNodeOp : public UnaryNodeOp { NodeOps backwardOps() { return { - NodeOp(Element(_1 += _2 * _3 * (1.0f - _3), - children_[0]->grad(), - adj_, - val_)) + NodeOp(Add(_1 * _2 * (1.0f - _2), + children_[0]->grad(), + adj_, val_)) }; } @@ -46,26 +45,70 @@ struct LogitNodeOp : public UnaryNodeOp { } }; -struct TanhNodeOp : public UnaryNodeOp { - template <typename ...Args> - TanhNodeOp(Args ...args) - : UnaryNodeOp(args...) { } +struct TanhNodeOp : public NaryNodeOp { + TanhNodeOp(const std::vector<Expr>& nodes) + : NaryNodeOp(nodes, keywords::shape=newShape(nodes)) { } + + Shape newShape(const std::vector<Expr>& nodes) { + Shape shape = nodes[0]->shape(); + + for(int n = 1; n < nodes.size(); ++n) { + Shape shapen = nodes[n]->shape(); + for(int i = 0; i < shapen.size(); ++i) { + UTIL_THROW_IF2(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1, + "Shapes cannot be broadcasted"); + shape.set(i, std::max(shape[i], shapen[i])); + } + } + return shape; + } NodeOps forwardOps() { - return { - NodeOp(Element(_1 = Tanh(_2), - val_, - children_[0]->val())) - }; + switch (children_.size()) { + case 1: + return { NodeOp(Element(_1 = Tanh(_2), + val_, + children_[0]->val())) }; + case 2: + return { NodeOp(Element(_1 = Tanh(_2 + _3), + val_, + children_[0]->val(), + children_[1]->val())) }; + case 3: + return { NodeOp(Element(_1 = Tanh(_2 + _3 + _4), + val_, + children_[0]->val(), + children_[1]->val(), + children_[2]->val())) }; + default: + return { + NodeOp( + Element(_1 = _2 + _3 + _4, + val_, + children_[0]->val(), + children_[1]->val(), + children_[2]->val()); + for(int i = 3; i < children_.size(); ++i) + Element(_1 += _2, val_, children_[i]->val()); + Element(_1 = Tanh(_1), val_); + ) + }; + } } NodeOps backwardOps() { - return { - NodeOp(Element(_1 += _2 * (1.0f - (_3 * _3)), - children_[0]->grad(), - adj_, - val_)) - }; + NodeOps ops; + for(auto&& child : children_) { + ops.push_back( + NodeOp(Add(_1 * (1.0f - (_2 * _2)), + child->grad(), adj_, val_)) + ); + } + return ops; + } + + const std::string color() { + return "yellow"; } const std::string type() { @@ -103,8 +146,9 @@ struct ReLUNodeOp : public UnaryNodeOp { NodeOps backwardOps() { return { - NodeOp(Element(_1 += _2 * ReLUback(_3), - children_[0]->grad(), adj_, children_[0]->val())) + NodeOp(Add(_1 * ReLUback(_2), + children_[0]->grad(), + adj_, children_[0]->val())) }; } |