From 352a437ab49ec00be944e11ed4bba0d52ac49931 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Thu, 28 Jun 2018 13:17:34 -0700 Subject: make cols operator non-constant --- src/graph/node_operators_unary.h | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) (limited to 'src/graph/node_operators_unary.h') diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 6749a585..dda4dd03 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -662,11 +662,11 @@ struct RowsNodeOp : public UnaryNodeOp { } template - Shape newShape(Expr a, const std::vector& indeces) { + Shape newShape(Expr a, const std::vector& indices) { Shape shape = a->shape(); ABORT_IF(shape.size() != 2, "rows operator can only be used with 2-dimensional tensors"); - shape.set(0, indeces.size()); + shape.set(0, indices.size()); return shape; } @@ -699,8 +699,10 @@ struct RowsNodeOp : public UnaryNodeOp { }; struct ColsNodeOp : public UnaryNodeOp { - ColsNodeOp(Expr a, const std::vector& indeces) - : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {} + ColsNodeOp(Expr a, const std::vector& indices) + : UnaryNodeOp(a, newShape(a, indices)), indices_(indices) { + setMemoize(false); + } NodeOps forwardOps() { // @TODO: solve this with a tensor! @@ -713,9 +715,9 @@ struct ColsNodeOp : public UnaryNodeOp { } template - Shape newShape(Expr a, const std::vector& indeces) { + Shape newShape(Expr a, const std::vector& indices) { Shape shape = a->shape(); - shape.set(1, indeces.size()); + shape.set(1, indices.size()); return shape; } @@ -748,8 +750,10 @@ struct ColsNodeOp : public UnaryNodeOp { }; struct SelectNodeOp : public UnaryNodeOp { - SelectNodeOp(Expr a, int axis, const std::vector& indeces) - : UnaryNodeOp(a, newShape(a, axis, indeces)), indices_(indeces) {} + SelectNodeOp(Expr a, int axis, const std::vector& indices) + : UnaryNodeOp(a, newShape(a, axis, indices)), indices_(indices) { + setMemoize(false); + } NodeOps forwardOps() { return {NodeOp( @@ -761,10 +765,10 @@ struct SelectNodeOp : public UnaryNodeOp { Insert(child(0)->grad(), adj_, axis_, indices_, graph()->allocator()))}; } - Shape newShape(Expr a, int axis, const std::vector& indeces) { + Shape newShape(Expr a, int axis, const std::vector& indices) { Shape shape = a->shape(); axis_ = shape.axis(axis); - shape.set(axis_, indeces.size()); + shape.set(axis_, indices.size()); return shape; } -- cgit v1.2.3