diff options
author | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-11-23 14:26:38 +0300 |
---|---|---|
committer | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-11-23 14:26:38 +0300 |
commit | 537fccc3e81831fb0aa8baaaaef2858c96c346ec (patch) | |
tree | 57a48f2e1ff7894289587255f503f3ab7744203f /src/graph/node_operators_unary.h | |
parent | 15947c6061c009679b0a00ca873d768c1159cd6f (diff) | |
parent | 9023667939b0fdd645f971cdeb0ab4e764b07057 (diff) |
Merge branch 'master' of https://github.com/marian-nmt/marian-dev into charS2S
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 95 |
1 files changed, 75 insertions, 20 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index faf21dee..8390a0c2 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -39,6 +39,25 @@ public: } const std::string type() { return "scalar_add"; } + + virtual size_t hash() { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + boost::hash_combine(hash_, scalar_); + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast<ScalarAddNodeOp>(node); + if(!cnode) + return false; + if(scalar_ != cnode->scalar_) + return false; + return true; + } }; struct ScalarMultNodeOp : public UnaryNodeOp { @@ -61,6 +80,25 @@ public: } const std::string type() { return "scalar_add"; } + + virtual size_t hash() { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + boost::hash_combine(hash_, scalar_); + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast<ScalarMultNodeOp>(node); + if(!cnode) + return false; + if(scalar_ != cnode->scalar_) + return false; + return true; + } }; struct LogitNodeOp : public UnaryNodeOp { @@ -256,6 +294,25 @@ struct PReLUNodeOp : public UnaryNodeOp { const std::string type() { return "PReLU"; } + virtual size_t hash() { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + boost::hash_combine(hash_, alpha_); + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast<PReLUNodeOp>(node); + if(!cnode) + return false; + if(alpha_ != cnode->alpha_) + return false; + return true; + } + private: float alpha_{0.01}; }; @@ -546,8 +603,6 @@ struct SqrtNodeOp : public UnaryNodeOp { }; struct SquareNodeOp : public UnaryNodeOp { - float epsilon_; - template <typename... Args> SquareNodeOp(Args... args) : UnaryNodeOp(args...) {} @@ -586,16 +641,16 @@ struct RowsNodeOp : public UnaryNodeOp { template <typename... Args> RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args) : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...), - indeces_(indeces) {} + indices_(indeces) {} NodeOps forwardOps() { // @TODO: solve this with a tensor! - return {NodeOp(CopyRows(val_, child(0)->val(), indeces_))}; + return {NodeOp(CopyRows(val_, child(0)->val(), indices_))}; } NodeOps backwardOps() { - return {NodeOp(PasteRows(child(0)->grad(), adj_, indeces_))}; + return {NodeOp(PasteRows(child(0)->grad(), adj_, indices_))}; } template <class... Args> @@ -614,7 +669,7 @@ struct RowsNodeOp : public UnaryNodeOp { virtual size_t hash() { if(!hash_) { size_t seed = NaryNodeOp::hash(); - for(auto i : indeces_) + for(auto i : indices_) boost::hash_combine(seed, i); hash_ = seed; } @@ -627,28 +682,28 @@ struct RowsNodeOp : public UnaryNodeOp { Ptr<RowsNodeOp> cnode = std::dynamic_pointer_cast<RowsNodeOp>(node); if(!cnode) return false; - if(indeces_ != cnode->indeces_) + if(indices_ != cnode->indices_) return false; return true; } - std::vector<size_t> indeces_; + std::vector<size_t> indices_; }; struct ColsNodeOp : public UnaryNodeOp { template <typename... Args> ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args) : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...), - indeces_(indeces) {} + indices_(indeces) {} NodeOps forwardOps() { // @TODO: solve this with a tensor! - return {NodeOp(CopyCols(val_, child(0)->val(), indeces_))}; + return {NodeOp(CopyCols(val_, child(0)->val(), indices_))}; } NodeOps backwardOps() { - return {NodeOp(PasteCols(child(0)->grad(), adj_, indeces_))}; + return {NodeOp(PasteCols(child(0)->grad(), adj_, indices_))}; } template <class... Args> @@ -665,7 +720,7 @@ struct ColsNodeOp : public UnaryNodeOp { virtual size_t hash() { if(!hash_) { size_t seed = NaryNodeOp::hash(); - for(auto i : indeces_) + for(auto i : indices_) boost::hash_combine(seed, i); hash_ = seed; } @@ -678,27 +733,27 @@ struct ColsNodeOp : public UnaryNodeOp { Ptr<ColsNodeOp> cnode = std::dynamic_pointer_cast<ColsNodeOp>(node); if(!cnode) return false; - if(indeces_ != cnode->indeces_) + if(indices_ != cnode->indices_) return false; return true; } - std::vector<size_t> indeces_; + std::vector<size_t> indices_; }; struct SelectNodeOp : public UnaryNodeOp { SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces) : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)), - indeces_(indeces) {} + indices_(indeces) {} NodeOps forwardOps() { return {NodeOp( - Select(graph()->allocator(), val_, child(0)->val(), axis_, indeces_))}; + Select(graph()->allocator(), val_, child(0)->val(), axis_, indices_))}; } NodeOps backwardOps() { return {NodeOp( - Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indeces_))}; + Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indices_))}; } Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) { @@ -716,7 +771,7 @@ struct SelectNodeOp : public UnaryNodeOp { if(!hash_) { size_t seed = NaryNodeOp::hash(); boost::hash_combine(seed, axis_); - for(auto i : indeces_) + for(auto i : indices_) boost::hash_combine(seed, i); hash_ = seed; } @@ -731,12 +786,12 @@ struct SelectNodeOp : public UnaryNodeOp { return false; if(axis_ != cnode->axis_) return false; - if(indeces_ != cnode->indeces_) + if(indices_ != cnode->indices_) return false; return true; } - std::vector<size_t> indeces_; + std::vector<size_t> indices_; int axis_{0}; }; |