diff options
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 05294bee..9d5b8287 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -55,7 +55,10 @@ public: return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))}; } - NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; } + NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(_1, child(0)->grad(), adj_))}; + } const std::string type() { return "scalar_add"; } }; @@ -392,9 +395,14 @@ struct SumNodeOp : public UnaryNodeOp { SumNodeOp(Expr a, Args... args) : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} - NodeOps forwardOps() { return {NodeOp(Reduce(_1, val_, child(0)->val()))}; } + NodeOps forwardOps() { + using namespace functional; + + return {NodeOp(Reduce(_1, val_, child(0)->val()))}; } - NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; } + NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(_1, child(0)->grad(), adj_))}; } template <class... Args> Shape newShape(Expr a, Args... args) { |