diff options
author | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2019-09-05 06:21:40 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2019-09-05 06:21:40 +0300 |
commit | 075e9ce934df3de0547039c6078665f92e1ee091 (patch) | |
tree | ae4281d53a0145fc2ada80b33764050ce4d38b74 /src/graph | |
parent | 967acf9175ecfb4628928bd7dc27510205c63fe1 (diff) |
check shapes for concatenation
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/node_operators_binary.h | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index 07245391..2cf49e36 100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -931,12 +931,21 @@ struct ConcatenateNodeOp : public NaryNodeOp { } Shape newShape(const std::vector<Expr>& nodes, int ax) { - Shape shape = nodes.back()->shape(); + ABORT_IF(nodes.empty(), "No child nodes given"); + + Shape shape = nodes[0]->shape(); ax_ = shape.axis(ax); int sum = 0; - for(auto child : nodes) + auto checkShape = shape; + for(auto child : nodes) { + checkShape.set(ax_, child->shape()[ax_]); // don't abort on different sizes on axis dim. + ABORT_IF(checkShape != child->shape(), + "Child shapes {} and {} cannot be concatenated along axis {}", + shape, child->shape(), ax); + sum += child->shape()[ax_]; + } shape.set(ax_, sum); return shape; @@ -953,8 +962,7 @@ struct ConcatenateNodeOp : public NaryNodeOp { std::vector<Tensor> deconcatenees; for(size_t i = 0; i < children_.size(); ++i) { auto childPtr = child(i); - childPtr - ->set_zero_adjoint(); // @TODO: this is a hotfix, do this properly + childPtr->set_zero_adjoint(); // @TODO: this is a hotfix, do this properly deconcatenees.push_back(childPtr->grad()); } Deconcatenate(deconcatenees, adj_, ax_); |