Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-28 23:17:34 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-28 23:17:34 +0300
commit352a437ab49ec00be944e11ed4bba0d52ac49931 (patch)
tree743dfa688033389f70eceadd055d77b1bf90418b /src/graph/node_operators_unary.h
parentb9197c2a5a4714c576f65a825553cfde429a09e1 (diff)
make cols operator non-constant
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h24
1 files changed, 14 insertions, 10 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 6749a585..dda4dd03 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -662,11 +662,11 @@ struct RowsNodeOp : public UnaryNodeOp {
}
template <class... Args>
- Shape newShape(Expr a, const std::vector<size_t>& indeces) {
+ Shape newShape(Expr a, const std::vector<size_t>& indices) {
Shape shape = a->shape();
ABORT_IF(shape.size() != 2,
"rows operator can only be used with 2-dimensional tensors");
- shape.set(0, indeces.size());
+ shape.set(0, indices.size());
return shape;
}
@@ -699,8 +699,10 @@ struct RowsNodeOp : public UnaryNodeOp {
};
struct ColsNodeOp : public UnaryNodeOp {
- ColsNodeOp(Expr a, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {}
+ ColsNodeOp(Expr a, const std::vector<size_t>& indices)
+ : UnaryNodeOp(a, newShape(a, indices)), indices_(indices) {
+ setMemoize(false);
+ }
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
@@ -713,9 +715,9 @@ struct ColsNodeOp : public UnaryNodeOp {
}
template <class... Args>
- Shape newShape(Expr a, const std::vector<size_t>& indeces) {
+ Shape newShape(Expr a, const std::vector<size_t>& indices) {
Shape shape = a->shape();
- shape.set(1, indeces.size());
+ shape.set(1, indices.size());
return shape;
}
@@ -748,8 +750,10 @@ struct ColsNodeOp : public UnaryNodeOp {
};
struct SelectNodeOp : public UnaryNodeOp {
- SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, axis, indeces)), indices_(indeces) {}
+ SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indices)
+ : UnaryNodeOp(a, newShape(a, axis, indices)), indices_(indices) {
+ setMemoize(false);
+ }
NodeOps forwardOps() {
return {NodeOp(
@@ -761,10 +765,10 @@ struct SelectNodeOp : public UnaryNodeOp {
Insert(child(0)->grad(), adj_, axis_, indices_, graph()->allocator()))};
}
- Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
+ Shape newShape(Expr a, int axis, const std::vector<size_t>& indices) {
Shape shape = a->shape();
axis_ = shape.axis(axis);
- shape.set(axis_, indeces.size());
+ shape.set(axis_, indices.size());
return shape;
}