From b3765f61bdbc30cb5f2a74ac7f882b8a0b9055ba Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Sun, 29 Oct 2017 16:50:02 +0100 Subject: replace TimeStepNode with more general StepNode --- src/graph/node_operators_unary.h | 58 +++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 18 deletions(-) (limited to 'src/graph/node_operators_unary.h') diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 7455d7f5..2c205f17 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -339,10 +339,9 @@ struct SumNodeOp : public UnaryNodeOp { if(ax != -1) { shape.set(ax, 1); } else { - shape.set(0, 1); - shape.set(1, 1); - shape.set(2, 1); - shape.set(3, 1); + for(int i = 0; i < shape.size(); ++i) { + shape.set(i, 1); + } } return shape; } @@ -400,10 +399,9 @@ struct MeanNodeOp : public UnaryNodeOp { if(ax != -1) { shape.set(ax, 1); } else { - shape.set(0, 1); - shape.set(1, 1); - shape.set(2, 1); - shape.set(3, 1); + for(int i = 0; i < shape.size(); ++i) { + shape.set(i, 1); + } } return shape; } @@ -707,9 +705,12 @@ struct TransposeNodeOp : public UnaryNodeOp { template Shape newShape(Expr a, Shape permute) { - Shape shape; + Shape shape = a->shape(); - for(int i = 0; i < 4; ++i) + UTIL_THROW_IF2(shape.size() != permute.size(), + "Shape and transpose axis have different number of dimensions"); + + for(int i = 0; i < shape.size(); ++i) shape.set(i, a->shape()[permute[i]]); return shape; @@ -805,23 +806,41 @@ public: } }; -class TimestepNodeOp : public UnaryNodeOp { +class StepNodeOp : public UnaryNodeOp { private: Expr stepNode_; size_t step_; + size_t axis_; public: - TimestepNodeOp(Expr a, size_t step) - : UnaryNodeOp(a, keywords::shape = newShape(a)), + StepNodeOp(Expr a, size_t step, size_t axis) + : UnaryNodeOp(a, keywords::shape = newShape(a, axis)), stepNode_(a), - step_(step) { + step_(step), + axis_(axis) { Node::destroy_ = false; } - Shape newShape(Expr a) { + Shape newShape(Expr a, size_t axis) { Shape outShape = a->shape(); - outShape.set(2, 1); - outShape.set(3, 1); + if(axis == 1) { + outShape.set(0, 1); + outShape.set(1, 1); + outShape.set(2, 1); + outShape.set(3, 1); + } + if(axis == 0) { + outShape.set(0, 1); + outShape.set(2, 1); + outShape.set(3, 1); + } + if(axis == 2) { + outShape.set(2, 1); + outShape.set(3, 1); + } + if(axis == 3) { + outShape.set(3, 1); + } return outShape; } @@ -861,6 +880,7 @@ public: if(!hash_) { hash_ = NaryNodeOp::hash(); boost::hash_combine(hash_, step_); + boost::hash_combine(hash_, axis_); } return hash_; } @@ -868,11 +888,13 @@ public: virtual bool equal(Expr node) { if(!NaryNodeOp::equal(node)) return false; - Ptr cnode = std::dynamic_pointer_cast(node); + Ptr cnode = std::dynamic_pointer_cast(node); if(!cnode) return false; if(step_ != cnode->step_) return false; + if(axis_ != cnode->axis_) + return false; return true; } }; -- cgit v1.2.3