Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-04-12 02:50:45 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-04-12 02:50:45 +0300
commit78a99473749ee038778f8b9ec37b16d0a62b86b7 (patch)
tree79d4664fc1d847262e060f1b74040e9ce66b88a1 /src/graph/node_operators_unary.h
parentbd06e1919ea908662ce6fe12d468be69c4f8c6f4 (diff)
working memoization
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h9
1 files changed, 6 insertions, 3 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index dc4015b2..273adf44 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -77,7 +77,7 @@ public:
return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))};
}
- const std::string type() { return "scalar_add"; }
+ const std::string type() { return "scalar_mult"; }
virtual size_t hash() {
if(!hash_) {
@@ -605,8 +605,11 @@ struct NegNodeOp : public UnaryNodeOp {
};
struct RowsNodeOp : public UnaryNodeOp {
- RowsNodeOp(Expr a, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {}
+ RowsNodeOp(Expr a, const std::vector<size_t>& indices)
+ : UnaryNodeOp(a, newShape(a, indices)), indices_(indices) {
+ // @TODO: fix this by using int32 tensor for indices
+ setMemoize(false);
+ }
NodeOps forwardOps() {
// @TODO: solve this with a tensor!