Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2017-10-19 15:29:07 +0300
committerRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2017-10-19 15:29:07 +0300
commit2ad0ca3ceeae719ecab208bec4da71a9068a5f7a (patch)
tree295601fcd1469054de5c975ae05a4319fcbbd5d1 /src/graph/node_operators_unary.h
parent10cf2e4ba639a0728b32f3a81adb72777dab341f (diff)
parent4a5e3878e62a1661a37b16c7dba09b14391e5397 (diff)
Merge with 'master' branch
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h51
1 files changed, 13 insertions, 38 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index e68231e3..af8502a6 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -621,11 +621,19 @@ struct SelectNodeOp : public UnaryNodeOp {
indeces_(indeces), axis_(axis) {}
NodeOps forwardOps() {
- return {NodeOp(Select(val_, child(0)->val(), axis_, indeces_))};
+ return {NodeOp(Select(graph()->allocator(),
+ val_,
+ child(0)->val(),
+ axis_,
+ indeces_))};
}
NodeOps backwardOps() {
- return {NodeOp(Insert(child(0)->grad(), adj_, axis_, indeces_))};
+ return {NodeOp(Insert(graph()->allocator(),
+ child(0)->grad(),
+ adj_,
+ axis_,
+ indeces_))};
}
Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
@@ -667,42 +675,9 @@ struct SelectNodeOp : public UnaryNodeOp {
};
struct TransposeNodeOp : public UnaryNodeOp {
- template <typename... Args>
- TransposeNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, keywords::shape = newShape(a), args...) {}
-
- NodeOps forwardOps() {
- return {NodeOp(Transpose(
- std::static_pointer_cast<BackendGPU>(getBackend())->getCublasHandle(),
- val_,
- child(0)->val()))};
- }
-
- NodeOps backwardOps() {
- return {NodeOp(Transpose(
- std::static_pointer_cast<BackendGPU>(getBackend())->getCublasHandle(),
- child(0)->grad(),
- adj_))};
- }
-
- template <class... Args>
- Shape newShape(Expr a) {
- Shape shape = a->shape();
- int temp = shape[0];
- shape.set(0, shape[1]);
- shape.set(1, temp);
- return shape;
- }
-
- const std::string type() { return "transpose"; }
-
- const std::string color() { return "orange"; }
-};
-
-struct Transpose4DNodeOp : public UnaryNodeOp {
Shape permute_;
- Transpose4DNodeOp(Expr a, Shape permute)
+ TransposeNodeOp(Expr a, Shape permute)
: UnaryNodeOp(a, keywords::shape = newShape(a, permute)),
permute_{permute} {}
@@ -737,7 +712,7 @@ struct Transpose4DNodeOp : public UnaryNodeOp {
virtual bool equal(Expr node) {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<Transpose4DNodeOp> cnode = std::dynamic_pointer_cast<Transpose4DNodeOp>(node);
+ Ptr<TransposeNodeOp> cnode = std::dynamic_pointer_cast<TransposeNodeOp>(node);
if(!cnode)
return false;
if(permute_ != cnode->permute_)
@@ -745,7 +720,7 @@ struct Transpose4DNodeOp : public UnaryNodeOp {
return true;
}
- const std::string type() { return "transpose4d"; }
+ const std::string type() { return "transpose"; }
const std::string color() { return "orange"; }
};