diff options
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index dc4015b2..273adf44 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -77,7 +77,7 @@ public: return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))}; } - const std::string type() { return "scalar_add"; } + const std::string type() { return "scalar_mult"; } virtual size_t hash() { if(!hash_) { @@ -605,8 +605,11 @@ struct NegNodeOp : public UnaryNodeOp { }; struct RowsNodeOp : public UnaryNodeOp { - RowsNodeOp(Expr a, const std::vector<size_t>& indeces) - : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {} + RowsNodeOp(Expr a, const std::vector<size_t>& indices) + : UnaryNodeOp(a, newShape(a, indices)), indices_(indices) { + // @TODO: fix this by using int32 tensor for indices + setMemoize(false); + } NodeOps forwardOps() { // @TODO: solve this with a tensor! |