Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-10-29 22:04:32 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-10-29 22:04:32 +0300
commit5ffe895d4c6f3561aa1eed0156a06f3333a10bea (patch)
tree7c70acafaf3ed8e309a0256de367b2286fbaca93 /src/graph/node_operators_unary.h
parentfe4a804d6692fa3bacd636e8b753f53911c59fbe (diff)
parentb3765f61bdbc30cb5f2a74ac7f882b8a0b9055ba (diff)
adjust rnn to work with new shape
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h72
1 files changed, 28 insertions, 44 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index a3f60366..08ecde46 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -141,17 +141,7 @@ struct TanhNodeOp : public NaryNodeOp {
: NaryNodeOp(nodes, keywords::shape = newShape(nodes)) {}
Shape newShape(const std::vector<Expr>& nodes) {
- Shape shape = nodes[0]->shape();
-
- for(int n = 1; n < nodes.size(); ++n) {
- Shape shapen = nodes[n]->shape();
- for(int i = 0; i < shapen.size(); ++i) {
- ABORT_IF(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1,
- "Shapes cannot be broadcasted");
- shape.set(i, std::max(shape[i], shapen[i]));
- }
- }
- return shape;
+ return Shape::broadcast(nodes);
}
NodeOps forwardOps() {
@@ -325,8 +315,7 @@ struct SumNodeOp : public UnaryNodeOp {
template <typename... Args>
SumNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...),
- ax_(keywords::Get(keywords::axis, -1, args...)) {}
+ : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
NodeOps forwardOps() { return {NodeOp(Reduce(_1, val_, child(0)->val()))}; }
@@ -334,15 +323,10 @@ struct SumNodeOp : public UnaryNodeOp {
template <class... Args>
Shape newShape(Expr a, Args... args) {
- int ax = keywords::Get(keywords::axis, -1, args...);
Shape shape = a->shape();
- if(ax != -1) {
- shape.set(ax, 1);
- } else {
- for(int i = 0; i < shape.size(); ++i) {
- shape.set(i, 1);
- }
- }
+ ax_ = shape.axis(keywords::Get(keywords::axis, -1, args...));
+
+ shape.set(ax_, 1);
return shape;
}
@@ -375,8 +359,7 @@ struct MeanNodeOp : public UnaryNodeOp {
template <typename... Args>
MeanNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...),
- ax_(keywords::Get(keywords::axis, -1, args...)) {}
+ : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
NodeOps forwardOps() {
int left = child(0)->shape().elements() / val_->shape().elements();
@@ -394,15 +377,9 @@ struct MeanNodeOp : public UnaryNodeOp {
template <class... Args>
Shape newShape(Expr a, Args... args) {
- int ax = keywords::Get(keywords::axis, -1, args...);
Shape shape = a->shape();
- if(ax != -1) {
- shape.set(ax, 1);
- } else {
- for(int i = 0; i < shape.size(); ++i) {
- shape.set(i, 1);
- }
- }
+ ax_ = shape.axis(keywords::Get(keywords::axis, -1, args...));
+ shape.set(ax_, 1);
return shape;
}
@@ -637,8 +614,7 @@ struct ColsNodeOp : public UnaryNodeOp {
struct SelectNodeOp : public UnaryNodeOp {
SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
: UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)),
- indeces_(indeces),
- axis_(axis) {}
+ indeces_(indeces) {}
NodeOps forwardOps() {
return {NodeOp(
@@ -652,7 +628,8 @@ struct SelectNodeOp : public UnaryNodeOp {
Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
Shape shape = a->shape();
- shape.set(axis, indeces.size());
+ axis_ = shape.axis(axis);
+ shape.set(axis_, indeces.size());
return shape;
}
@@ -707,8 +684,8 @@ struct TransposeNodeOp : public UnaryNodeOp {
Shape newShape(Expr a, Shape permute) {
Shape shape = a->shape();
- UTIL_THROW_IF2(shape.size() != permute.size(),
- "Shape and transpose axis have different number of dimensions");
+ ABORT_IF(shape.size() != permute.size(),
+ "Shape and transpose axis have different number of dimensions");
for(int i = 0; i < shape.size(); ++i)
shape.set(i, a->shape()[permute[i]]);
@@ -806,23 +783,27 @@ public:
}
};
-class TimestepNodeOp : public UnaryNodeOp {
+class StepNodeOp : public UnaryNodeOp {
private:
Expr stepNode_;
- size_t step_;
+ int step_;
+ int axis_;
public:
- TimestepNodeOp(Expr a, size_t step)
- : UnaryNodeOp(a, keywords::shape = newShape(a)),
+ StepNodeOp(Expr a, int step, int axis)
+ : UnaryNodeOp(a, keywords::shape = newShape(a, axis)),
stepNode_(a),
step_(step) {
Node::destroy_ = false;
}
- Shape newShape(Expr a) {
+ Shape newShape(Expr a, int axis) {
Shape outShape = a->shape();
- outShape.set(2, 1);
- outShape.set(3, 1);
+
+ axis_ = outShape.axis(axis);
+ for(int i = 0; i <= axis_; ++i)
+ outShape.set(i, 1);
+
return outShape;
}
@@ -862,6 +843,7 @@ public:
if(!hash_) {
hash_ = NaryNodeOp::hash();
boost::hash_combine(hash_, step_);
+ boost::hash_combine(hash_, axis_);
}
return hash_;
}
@@ -869,11 +851,13 @@ public:
virtual bool equal(Expr node) {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<TimestepNodeOp> cnode = std::dynamic_pointer_cast<TimestepNodeOp>(node);
+ Ptr<StepNodeOp> cnode = std::dynamic_pointer_cast<StepNodeOp>(node);
if(!cnode)
return false;
if(step_ != cnode->step_)
return false;
+ if(axis_ != cnode->axis_)
+ return false;
return true;
}
};