diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-29 22:04:32 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-29 22:04:32 +0300 |
commit | 5ffe895d4c6f3561aa1eed0156a06f3333a10bea (patch) | |
tree | 7c70acafaf3ed8e309a0256de367b2286fbaca93 /src/graph/node_operators_unary.h | |
parent | fe4a804d6692fa3bacd636e8b753f53911c59fbe (diff) | |
parent | b3765f61bdbc30cb5f2a74ac7f882b8a0b9055ba (diff) |
adjust rnn to work with new shape
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 72 |
1 files changed, 28 insertions, 44 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index a3f60366..08ecde46 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -141,17 +141,7 @@ struct TanhNodeOp : public NaryNodeOp { : NaryNodeOp(nodes, keywords::shape = newShape(nodes)) {} Shape newShape(const std::vector<Expr>& nodes) { - Shape shape = nodes[0]->shape(); - - for(int n = 1; n < nodes.size(); ++n) { - Shape shapen = nodes[n]->shape(); - for(int i = 0; i < shapen.size(); ++i) { - ABORT_IF(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1, - "Shapes cannot be broadcasted"); - shape.set(i, std::max(shape[i], shapen[i])); - } - } - return shape; + return Shape::broadcast(nodes); } NodeOps forwardOps() { @@ -325,8 +315,7 @@ struct SumNodeOp : public UnaryNodeOp { template <typename... Args> SumNodeOp(Expr a, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...), - ax_(keywords::Get(keywords::axis, -1, args...)) {} + : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} NodeOps forwardOps() { return {NodeOp(Reduce(_1, val_, child(0)->val()))}; } @@ -334,15 +323,10 @@ struct SumNodeOp : public UnaryNodeOp { template <class... Args> Shape newShape(Expr a, Args... args) { - int ax = keywords::Get(keywords::axis, -1, args...); Shape shape = a->shape(); - if(ax != -1) { - shape.set(ax, 1); - } else { - for(int i = 0; i < shape.size(); ++i) { - shape.set(i, 1); - } - } + ax_ = shape.axis(keywords::Get(keywords::axis, -1, args...)); + + shape.set(ax_, 1); return shape; } @@ -375,8 +359,7 @@ struct MeanNodeOp : public UnaryNodeOp { template <typename... Args> MeanNodeOp(Expr a, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...), - ax_(keywords::Get(keywords::axis, -1, args...)) {} + : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} NodeOps forwardOps() { int left = child(0)->shape().elements() / val_->shape().elements(); @@ -394,15 +377,9 @@ struct MeanNodeOp : public UnaryNodeOp { template <class... Args> Shape newShape(Expr a, Args... args) { - int ax = keywords::Get(keywords::axis, -1, args...); Shape shape = a->shape(); - if(ax != -1) { - shape.set(ax, 1); - } else { - for(int i = 0; i < shape.size(); ++i) { - shape.set(i, 1); - } - } + ax_ = shape.axis(keywords::Get(keywords::axis, -1, args...)); + shape.set(ax_, 1); return shape; } @@ -637,8 +614,7 @@ struct ColsNodeOp : public UnaryNodeOp { struct SelectNodeOp : public UnaryNodeOp { SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces) : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)), - indeces_(indeces), - axis_(axis) {} + indeces_(indeces) {} NodeOps forwardOps() { return {NodeOp( @@ -652,7 +628,8 @@ struct SelectNodeOp : public UnaryNodeOp { Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) { Shape shape = a->shape(); - shape.set(axis, indeces.size()); + axis_ = shape.axis(axis); + shape.set(axis_, indeces.size()); return shape; } @@ -707,8 +684,8 @@ struct TransposeNodeOp : public UnaryNodeOp { Shape newShape(Expr a, Shape permute) { Shape shape = a->shape(); - UTIL_THROW_IF2(shape.size() != permute.size(), - "Shape and transpose axis have different number of dimensions"); + ABORT_IF(shape.size() != permute.size(), + "Shape and transpose axis have different number of dimensions"); for(int i = 0; i < shape.size(); ++i) shape.set(i, a->shape()[permute[i]]); @@ -806,23 +783,27 @@ public: } }; -class TimestepNodeOp : public UnaryNodeOp { +class StepNodeOp : public UnaryNodeOp { private: Expr stepNode_; - size_t step_; + int step_; + int axis_; public: - TimestepNodeOp(Expr a, size_t step) - : UnaryNodeOp(a, keywords::shape = newShape(a)), + StepNodeOp(Expr a, int step, int axis) + : UnaryNodeOp(a, keywords::shape = newShape(a, axis)), stepNode_(a), step_(step) { Node::destroy_ = false; } - Shape newShape(Expr a) { + Shape newShape(Expr a, int axis) { Shape outShape = a->shape(); - outShape.set(2, 1); - outShape.set(3, 1); + + axis_ = outShape.axis(axis); + for(int i = 0; i <= axis_; ++i) + outShape.set(i, 1); + return outShape; } @@ -862,6 +843,7 @@ public: if(!hash_) { hash_ = NaryNodeOp::hash(); boost::hash_combine(hash_, step_); + boost::hash_combine(hash_, axis_); } return hash_; } @@ -869,11 +851,13 @@ public: virtual bool equal(Expr node) { if(!NaryNodeOp::equal(node)) return false; - Ptr<TimestepNodeOp> cnode = std::dynamic_pointer_cast<TimestepNodeOp>(node); + Ptr<StepNodeOp> cnode = std::dynamic_pointer_cast<StepNodeOp>(node); if(!cnode) return false; if(step_ != cnode->step_) return false; + if(axis_ != cnode->axis_) + return false; return true; } }; |