diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-11-02 22:58:56 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-11-02 22:58:56 +0300 |
commit | 9d06d786ba568e74f54be98f5a23d9b53d87ed7b (patch) | |
tree | 8d82c2e2c33d9fa113d5bd9fd52bb8cd099db3bd /src/graph/node_operators_unary.h | |
parent | ca64c429e4aa4dcd49b68bab6a7d744fe06b44c2 (diff) |
farewell thrust
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 05294bee..9d5b8287 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -55,7 +55,10 @@ public: return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))}; } - NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; } + NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(_1, child(0)->grad(), adj_))}; + } const std::string type() { return "scalar_add"; } }; @@ -392,9 +395,14 @@ struct SumNodeOp : public UnaryNodeOp { SumNodeOp(Expr a, Args... args) : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} - NodeOps forwardOps() { return {NodeOp(Reduce(_1, val_, child(0)->val()))}; } + NodeOps forwardOps() { + using namespace functional; + + return {NodeOp(Reduce(_1, val_, child(0)->val()))}; } - NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; } + NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(_1, child(0)->grad(), adj_))}; } template <class... Args> Shape newShape(Expr a, Args... args) { |