diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-19 15:09:38 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-19 15:09:38 +0300 |
commit | 4a5e3878e62a1661a37b16c7dba09b14391e5397 (patch) | |
tree | a68eebc6fc4d50f61f31850a493f41a99071cf0a /src/graph/node_operators_unary.h | |
parent | 113eab7d7955a7a1ab0f1dc6dd14371715bb4439 (diff) |
uses memory allocator for temp memory in kernels
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 51 |
1 files changed, 13 insertions, 38 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index bcc05d15..06932bea 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -611,11 +611,19 @@ struct SelectNodeOp : public UnaryNodeOp { indeces_(indeces), axis_(axis) {} NodeOps forwardOps() { - return {NodeOp(Select(val_, child(0)->val(), axis_, indeces_))}; + return {NodeOp(Select(graph()->allocator(), + val_, + child(0)->val(), + axis_, + indeces_))}; } NodeOps backwardOps() { - return {NodeOp(Insert(child(0)->grad(), adj_, axis_, indeces_))}; + return {NodeOp(Insert(graph()->allocator(), + child(0)->grad(), + adj_, + axis_, + indeces_))}; } Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) { @@ -657,42 +665,9 @@ struct SelectNodeOp : public UnaryNodeOp { }; struct TransposeNodeOp : public UnaryNodeOp { - template <typename... Args> - TransposeNodeOp(Expr a, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a), args...) {} - - NodeOps forwardOps() { - return {NodeOp(Transpose( - std::static_pointer_cast<BackendGPU>(getBackend())->getCublasHandle(), - val_, - child(0)->val()))}; - } - - NodeOps backwardOps() { - return {NodeOp(Transpose( - std::static_pointer_cast<BackendGPU>(getBackend())->getCublasHandle(), - child(0)->grad(), - adj_))}; - } - - template <class... Args> - Shape newShape(Expr a) { - Shape shape = a->shape(); - int temp = shape[0]; - shape.set(0, shape[1]); - shape.set(1, temp); - return shape; - } - - const std::string type() { return "transpose"; } - - const std::string color() { return "orange"; } -}; - -struct Transpose4DNodeOp : public UnaryNodeOp { Shape permute_; - Transpose4DNodeOp(Expr a, Shape permute) + TransposeNodeOp(Expr a, Shape permute) : UnaryNodeOp(a, keywords::shape = newShape(a, permute)), permute_{permute} {} @@ -727,7 +702,7 @@ struct Transpose4DNodeOp : public UnaryNodeOp { virtual bool equal(Expr node) { if(!NaryNodeOp::equal(node)) return false; - Ptr<Transpose4DNodeOp> cnode = std::dynamic_pointer_cast<Transpose4DNodeOp>(node); + Ptr<TransposeNodeOp> cnode = std::dynamic_pointer_cast<TransposeNodeOp>(node); if(!cnode) return false; if(permute_ != cnode->permute_) @@ -735,7 +710,7 @@ struct Transpose4DNodeOp : public UnaryNodeOp { return true; } - const std::string type() { return "transpose4d"; } + const std::string type() { return "transpose"; } const std::string color() { return "orange"; } }; |