diff options
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 08ecde46..173c8778 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -526,9 +526,9 @@ struct RowsNodeOp : public UnaryNodeOp { template <class... Args> Shape newShape(Expr a, const std::vector<size_t>& indeces) { Shape shape = a->shape(); + ABORT_IF(shape.size() != 2, + "rows operator can only be used with 2-dimensional tensors"); shape.set(0, indeces.size()); - shape.set(2, 1); - shape.set(3, 1); return shape; } |