Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src/graph
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2019-09-05 06:21:40 +0300
committerMarcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2019-09-05 06:21:40 +0300
commit075e9ce934df3de0547039c6078665f92e1ee091 (patch)
treeae4281d53a0145fc2ada80b33764050ce4d38b74 /src/graph
parent967acf9175ecfb4628928bd7dc27510205c63fe1 (diff)
check shapes for concatenation
Diffstat (limited to 'src/graph')
-rw-r--r--src/graph/node_operators_binary.h16
1 files changed, 12 insertions, 4 deletions
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 07245391..2cf49e36 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -931,12 +931,21 @@ struct ConcatenateNodeOp : public NaryNodeOp {
}
Shape newShape(const std::vector<Expr>& nodes, int ax) {
- Shape shape = nodes.back()->shape();
+ ABORT_IF(nodes.empty(), "No child nodes given");
+
+ Shape shape = nodes[0]->shape();
ax_ = shape.axis(ax);
int sum = 0;
- for(auto child : nodes)
+ auto checkShape = shape;
+ for(auto child : nodes) {
+ checkShape.set(ax_, child->shape()[ax_]); // don't abort on different sizes on axis dim.
+ ABORT_IF(checkShape != child->shape(),
+ "Child shapes {} and {} cannot be concatenated along axis {}",
+ shape, child->shape(), ax);
+
sum += child->shape()[ax_];
+ }
shape.set(ax_, sum);
return shape;
@@ -953,8 +962,7 @@ struct ConcatenateNodeOp : public NaryNodeOp {
std::vector<Tensor> deconcatenees;
for(size_t i = 0; i < children_.size(); ++i) {
auto childPtr = child(i);
- childPtr
- ->set_zero_adjoint(); // @TODO: this is a hotfix, do this properly
+ childPtr->set_zero_adjoint(); // @TODO: this is a hotfix, do this properly
deconcatenees.push_back(childPtr->grad());
}
Deconcatenate(deconcatenees, adj_, ax_);