diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-09-19 03:20:28 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-09-19 03:20:28 +0300 |
commit | aed9f48660bdf504a2b0db6ad6261f17e679c591 (patch) | |
tree | bc1196543ff98799f417521799e77f6f32f063f7 /src/graph/node_operators_unary.h | |
parent | 9f79c919dc886408d7180e207f60a2a3aa9af3d6 (diff) |
more complrex transformer
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 76b5f376..a269e501 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -598,6 +598,57 @@ struct TransposeNodeOp : public UnaryNodeOp { const std::string color() { return "orange"; } }; +struct Transpose4DNodeOp : public UnaryNodeOp { + Shape permute_; + + Transpose4DNodeOp(Expr a, Shape permute) + : UnaryNodeOp(a, keywords::shape = newShape(a, permute)), + permute_{permute} {} + + NodeOps forwardOps() { + return { NodeOp(Transpose4D(val_, child(0)->val(), permute_)) }; + } + + NodeOps backwardOps() { + return { NodeOp(Transpose4D(child(0)->grad(), adj_, permute_)) }; + } + + template <class... Args> + Shape newShape(Expr a, Shape permute) { + Shape shape; + + for(int i = 0; i < 4; ++i) + shape.set(i, a->shape()[permute[i]]); + + return shape; + } + + virtual size_t hash() { + if(!hash_) { + size_t seed = NaryNodeOp::hash(); + for(auto s : permute_) + boost::hash_combine(seed, s); + hash_ = seed; + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + Ptr<Transpose4DNodeOp> cnode = std::dynamic_pointer_cast<Transpose4DNodeOp>(node); + if(!cnode) + return false; + if(permute_ != cnode->permute_) + return false; + return true; + } + + const std::string type() { return "transpose4d"; } + + const std::string color() { return "orange"; } +}; + class ReshapeNodeOp : public UnaryNodeOp { private: Expr reshapee_; |