diff options
author | Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> | 2017-10-19 15:29:07 +0300 |
---|---|---|
committer | Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> | 2017-10-19 15:29:07 +0300 |
commit | 2ad0ca3ceeae719ecab208bec4da71a9068a5f7a (patch) | |
tree | 295601fcd1469054de5c975ae05a4319fcbbd5d1 /src/graph/node_operators_unary.h | |
parent | 10cf2e4ba639a0728b32f3a81adb72777dab341f (diff) | |
parent | 4a5e3878e62a1661a37b16c7dba09b14391e5397 (diff) |
Merge with 'master' branch
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 51 |
1 files changed, 13 insertions, 38 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index e68231e3..af8502a6 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -621,11 +621,19 @@ struct SelectNodeOp : public UnaryNodeOp { indeces_(indeces), axis_(axis) {} NodeOps forwardOps() { - return {NodeOp(Select(val_, child(0)->val(), axis_, indeces_))}; + return {NodeOp(Select(graph()->allocator(), + val_, + child(0)->val(), + axis_, + indeces_))}; } NodeOps backwardOps() { - return {NodeOp(Insert(child(0)->grad(), adj_, axis_, indeces_))}; + return {NodeOp(Insert(graph()->allocator(), + child(0)->grad(), + adj_, + axis_, + indeces_))}; } Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) { @@ -667,42 +675,9 @@ struct SelectNodeOp : public UnaryNodeOp { }; struct TransposeNodeOp : public UnaryNodeOp { - template <typename... Args> - TransposeNodeOp(Expr a, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a), args...) {} - - NodeOps forwardOps() { - return {NodeOp(Transpose( - std::static_pointer_cast<BackendGPU>(getBackend())->getCublasHandle(), - val_, - child(0)->val()))}; - } - - NodeOps backwardOps() { - return {NodeOp(Transpose( - std::static_pointer_cast<BackendGPU>(getBackend())->getCublasHandle(), - child(0)->grad(), - adj_))}; - } - - template <class... Args> - Shape newShape(Expr a) { - Shape shape = a->shape(); - int temp = shape[0]; - shape.set(0, shape[1]); - shape.set(1, temp); - return shape; - } - - const std::string type() { return "transpose"; } - - const std::string color() { return "orange"; } -}; - -struct Transpose4DNodeOp : public UnaryNodeOp { Shape permute_; - Transpose4DNodeOp(Expr a, Shape permute) + TransposeNodeOp(Expr a, Shape permute) : UnaryNodeOp(a, keywords::shape = newShape(a, permute)), permute_{permute} {} @@ -737,7 +712,7 @@ struct Transpose4DNodeOp : public UnaryNodeOp { virtual bool equal(Expr node) { if(!NaryNodeOp::equal(node)) return false; - Ptr<Transpose4DNodeOp> cnode = std::dynamic_pointer_cast<Transpose4DNodeOp>(node); + Ptr<TransposeNodeOp> cnode = std::dynamic_pointer_cast<TransposeNodeOp>(node); if(!cnode) return false; if(permute_ != cnode->permute_) @@ -745,7 +720,7 @@ struct Transpose4DNodeOp : public UnaryNodeOp { return true; } - const std::string type() { return "transpose4d"; } + const std::string type() { return "transpose"; } const std::string color() { return "orange"; } }; |