diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-11-02 22:58:56 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-11-02 22:58:56 +0300 |
commit | 9d06d786ba568e74f54be98f5a23d9b53d87ed7b (patch) | |
tree | 8d82c2e2c33d9fa113d5bd9fd52bb8cd099db3bd /src/graph | |
parent | ca64c429e4aa4dcd49b68bab6a7d744fe06b44c2 (diff) |
farewell thrust
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/node_operators_binary.h | 24 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 14 |
2 files changed, 35 insertions, 3 deletions
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index f5fe6cf5..86fafdfb 100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -193,6 +193,8 @@ struct AffineNodeOp : public NaryNodeOp { } NodeOps forwardOps() { + using namespace functional; + return { NodeOp(Prod(std::static_pointer_cast<BackendGPU>(getBackend()) ->getCublasHandle(), @@ -207,6 +209,8 @@ struct AffineNodeOp : public NaryNodeOp { } NodeOps backwardOps() { + using namespace functional; + // D is the adjoint, the matrix of derivatives // df/dA += D*B.T // df/dB += A.T*D @@ -405,10 +409,14 @@ struct ScalarProductNodeOp : public NaryNodeOp { } NodeOps forwardOps() { + using namespace functional; + return {NodeOp(Reduce(_1 * _2, val_, child(0)->val(), child(1)->val()))}; } NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(_1 * _2, child(0)->grad(), child(1)->val(), adj_)), NodeOp(Add(_1 * _2, child(1)->grad(), child(0)->val(), adj_))}; } @@ -435,11 +443,15 @@ struct PlusNodeOp : public ElementBinaryNodeOp { PlusNodeOp(Args... args) : ElementBinaryNodeOp(args...) {} NodeOps forwardOps() { + using namespace functional; + return { NodeOp(Element(_1 = _2 + _3, val_, child(0)->val(), child(1)->val()))}; } NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(_1, child(0)->grad(), adj_)), NodeOp(Add(_1, child(1)->grad(), adj_))}; } @@ -452,11 +464,15 @@ struct MinusNodeOp : public ElementBinaryNodeOp { MinusNodeOp(Args... args) : ElementBinaryNodeOp(args...) {} NodeOps forwardOps() { + using namespace functional; + return { NodeOp(Element(_1 = _2 - _3, val_, child(0)->val(), child(1)->val()))}; } NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(_1, child(0)->grad(), adj_)), NodeOp(Add(-_1, child(1)->grad(), adj_))}; } @@ -469,11 +485,15 @@ struct MultNodeOp : public ElementBinaryNodeOp { MultNodeOp(Args... args) : ElementBinaryNodeOp(args...) {} NodeOps forwardOps() { + using namespace functional; + return { NodeOp(Element(_1 = _2 * _3, val_, child(0)->val(), child(1)->val()))}; } NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(_1 * _2, child(0)->grad(), adj_, child(1)->val())), NodeOp(Add(_1 * _2, child(1)->grad(), adj_, child(0)->val()))}; } @@ -486,11 +506,15 @@ struct DivNodeOp : public ElementBinaryNodeOp { DivNodeOp(Args... args) : ElementBinaryNodeOp(args...) {} NodeOps forwardOps() { + using namespace functional; + return { NodeOp(Element(_1 = _2 / _3, val_, child(0)->val(), child(1)->val()))}; } NodeOps backwardOps() { + using namespace functional; + return { NodeOp(Add(_1 * 1.0f / _2, child(0)->grad(), adj_, child(1)->val())), NodeOp(Add(-_1 * _2 / (_3 * _3), diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 05294bee..9d5b8287 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -55,7 +55,10 @@ public: return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))}; } - NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; } + NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(_1, child(0)->grad(), adj_))}; + } const std::string type() { return "scalar_add"; } }; @@ -392,9 +395,14 @@ struct SumNodeOp : public UnaryNodeOp { SumNodeOp(Expr a, Args... args) : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} - NodeOps forwardOps() { return {NodeOp(Reduce(_1, val_, child(0)->val()))}; } + NodeOps forwardOps() { + using namespace functional; + + return {NodeOp(Reduce(_1, val_, child(0)->val()))}; } - NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; } + NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(_1, child(0)->grad(), adj_))}; } template <class... Args> Shape newShape(Expr a, Args... args) { |