diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-06-28 23:17:34 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-06-28 23:17:34 +0300 |
commit | 352a437ab49ec00be944e11ed4bba0d52ac49931 (patch) | |
tree | 743dfa688033389f70eceadd055d77b1bf90418b /src/graph/node_operators_unary.h | |
parent | b9197c2a5a4714c576f65a825553cfde429a09e1 (diff) |
make cols operator non-constant
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 6749a585..dda4dd03 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -662,11 +662,11 @@ struct RowsNodeOp : public UnaryNodeOp { } template <class... Args> - Shape newShape(Expr a, const std::vector<size_t>& indeces) { + Shape newShape(Expr a, const std::vector<size_t>& indices) { Shape shape = a->shape(); ABORT_IF(shape.size() != 2, "rows operator can only be used with 2-dimensional tensors"); - shape.set(0, indeces.size()); + shape.set(0, indices.size()); return shape; } @@ -699,8 +699,10 @@ struct RowsNodeOp : public UnaryNodeOp { }; struct ColsNodeOp : public UnaryNodeOp { - ColsNodeOp(Expr a, const std::vector<size_t>& indeces) - : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {} + ColsNodeOp(Expr a, const std::vector<size_t>& indices) + : UnaryNodeOp(a, newShape(a, indices)), indices_(indices) { + setMemoize(false); + } NodeOps forwardOps() { // @TODO: solve this with a tensor! @@ -713,9 +715,9 @@ struct ColsNodeOp : public UnaryNodeOp { } template <class... Args> - Shape newShape(Expr a, const std::vector<size_t>& indeces) { + Shape newShape(Expr a, const std::vector<size_t>& indices) { Shape shape = a->shape(); - shape.set(1, indeces.size()); + shape.set(1, indices.size()); return shape; } @@ -748,8 +750,10 @@ struct ColsNodeOp : public UnaryNodeOp { }; struct SelectNodeOp : public UnaryNodeOp { - SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces) - : UnaryNodeOp(a, newShape(a, axis, indeces)), indices_(indeces) {} + SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indices) + : UnaryNodeOp(a, newShape(a, axis, indices)), indices_(indices) { + setMemoize(false); + } NodeOps forwardOps() { return {NodeOp( @@ -761,10 +765,10 @@ struct SelectNodeOp : public UnaryNodeOp { Insert(child(0)->grad(), adj_, axis_, indices_, graph()->allocator()))}; } - Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) { + Shape newShape(Expr a, int axis, const std::vector<size_t>& indices) { Shape shape = a->shape(); axis_ = shape.axis(axis); - shape.set(axis_, indeces.size()); + shape.set(axis_, indices.size()); return shape; } |