diff options
author | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-11-23 14:26:38 +0300 |
---|---|---|
committer | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-11-23 14:26:38 +0300 |
commit | 537fccc3e81831fb0aa8baaaaef2858c96c346ec (patch) | |
tree | 57a48f2e1ff7894289587255f503f3ab7744203f /src/graph | |
parent | 15947c6061c009679b0a00ca873d768c1159cd6f (diff) | |
parent | 9023667939b0fdd645f971cdeb0ab4e764b07057 (diff) |
Merge branch 'master' of https://github.com/marian-nmt/marian-dev into charS2S
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/expression_operators.cu | 20 | ||||
-rw-r--r-- | src/graph/expression_operators.h | 3 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 95 |
3 files changed, 98 insertions, 20 deletions
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu index 4c4e0feb..d657ba74 100644 --- a/src/graph/expression_operators.cu +++ b/src/graph/expression_operators.cu @@ -121,6 +121,13 @@ Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax) { return Expression<ConcatenateNodeOp>(concats, ax); } +Expr repeat(Expr a, size_t repeats, keywords::axis_k ax) { + if(repeats == 1) + return a; + return concatenate(std::vector<Expr>(repeats, a), ax); +} + + Expr reshape(Expr a, Shape shape) { return Expression<ReshapeNodeOp>(a, shape); } @@ -137,6 +144,10 @@ Expr atleast_3d(Expr a, size_t dims) { return atleast_nd(a, 3); } +Expr atleast_4d(Expr a) { + return atleast_nd(a, 4); +} + Expr atleast_nd(Expr a, size_t dims) { if(a->shape().size() >= dims) return a; @@ -154,6 +165,15 @@ Expr flatten(Expr a) { return Expression<ReshapeNodeOp>(a, shape); } +Expr flatten_2d(Expr a) { + Shape shape = { + a->shape().elements() / a->shape()[-1], + a->shape()[-1] + }; + + return Expression<ReshapeNodeOp>(a, shape); +} + Expr rows(Expr a, const std::vector<size_t>& indices) { return Expression<RowsNodeOp>(a, indices); } diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index 37e7c137..7302deab 100644 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -73,15 +73,18 @@ Expr transpose(Expr a); Expr transpose(Expr a, const std::vector<int>& axes); Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax = 0); +Expr repeat(Expr a, size_t repeats, keywords::axis_k ax = 0); Expr reshape(Expr a, Shape shape); Expr atleast_1d(Expr a); Expr atleast_2d(Expr a); Expr atleast_3d(Expr a); +Expr atleast_4d(Expr a); Expr atleast_nd(Expr a, size_t dims); Expr flatten(Expr a); +Expr flatten_2d(Expr a); Expr rows(Expr a, const std::vector<size_t>& indices); Expr cols(Expr a, const std::vector<size_t>& indices); diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index faf21dee..8390a0c2 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -39,6 +39,25 @@ public: } const std::string type() { return "scalar_add"; } + + virtual size_t hash() { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + boost::hash_combine(hash_, scalar_); + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast<ScalarAddNodeOp>(node); + if(!cnode) + return false; + if(scalar_ != cnode->scalar_) + return false; + return true; + } }; struct ScalarMultNodeOp : public UnaryNodeOp { @@ -61,6 +80,25 @@ public: } const std::string type() { return "scalar_add"; } + + virtual size_t hash() { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + boost::hash_combine(hash_, scalar_); + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast<ScalarMultNodeOp>(node); + if(!cnode) + return false; + if(scalar_ != cnode->scalar_) + return false; + return true; + } }; struct LogitNodeOp : public UnaryNodeOp { @@ -256,6 +294,25 @@ struct PReLUNodeOp : public UnaryNodeOp { const std::string type() { return "PReLU"; } + virtual size_t hash() { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + boost::hash_combine(hash_, alpha_); + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast<PReLUNodeOp>(node); + if(!cnode) + return false; + if(alpha_ != cnode->alpha_) + return false; + return true; + } + private: float alpha_{0.01}; }; @@ -546,8 +603,6 @@ struct SqrtNodeOp : public UnaryNodeOp { }; struct SquareNodeOp : public UnaryNodeOp { - float epsilon_; - template <typename... Args> SquareNodeOp(Args... args) : UnaryNodeOp(args...) {} @@ -586,16 +641,16 @@ struct RowsNodeOp : public UnaryNodeOp { template <typename... Args> RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args) : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...), - indeces_(indeces) {} + indices_(indeces) {} NodeOps forwardOps() { // @TODO: solve this with a tensor! - return {NodeOp(CopyRows(val_, child(0)->val(), indeces_))}; + return {NodeOp(CopyRows(val_, child(0)->val(), indices_))}; } NodeOps backwardOps() { - return {NodeOp(PasteRows(child(0)->grad(), adj_, indeces_))}; + return {NodeOp(PasteRows(child(0)->grad(), adj_, indices_))}; } template <class... Args> @@ -614,7 +669,7 @@ struct RowsNodeOp : public UnaryNodeOp { virtual size_t hash() { if(!hash_) { size_t seed = NaryNodeOp::hash(); - for(auto i : indeces_) + for(auto i : indices_) boost::hash_combine(seed, i); hash_ = seed; } @@ -627,28 +682,28 @@ struct RowsNodeOp : public UnaryNodeOp { Ptr<RowsNodeOp> cnode = std::dynamic_pointer_cast<RowsNodeOp>(node); if(!cnode) return false; - if(indeces_ != cnode->indeces_) + if(indices_ != cnode->indices_) return false; return true; } - std::vector<size_t> indeces_; + std::vector<size_t> indices_; }; struct ColsNodeOp : public UnaryNodeOp { template <typename... Args> ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args) : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...), - indeces_(indeces) {} + indices_(indeces) {} NodeOps forwardOps() { // @TODO: solve this with a tensor! - return {NodeOp(CopyCols(val_, child(0)->val(), indeces_))}; + return {NodeOp(CopyCols(val_, child(0)->val(), indices_))}; } NodeOps backwardOps() { - return {NodeOp(PasteCols(child(0)->grad(), adj_, indeces_))}; + return {NodeOp(PasteCols(child(0)->grad(), adj_, indices_))}; } template <class... Args> @@ -665,7 +720,7 @@ struct ColsNodeOp : public UnaryNodeOp { virtual size_t hash() { if(!hash_) { size_t seed = NaryNodeOp::hash(); - for(auto i : indeces_) + for(auto i : indices_) boost::hash_combine(seed, i); hash_ = seed; } @@ -678,27 +733,27 @@ struct ColsNodeOp : public UnaryNodeOp { Ptr<ColsNodeOp> cnode = std::dynamic_pointer_cast<ColsNodeOp>(node); if(!cnode) return false; - if(indeces_ != cnode->indeces_) + if(indices_ != cnode->indices_) return false; return true; } - std::vector<size_t> indeces_; + std::vector<size_t> indices_; }; struct SelectNodeOp : public UnaryNodeOp { SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces) : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)), - indeces_(indeces) {} + indices_(indeces) {} NodeOps forwardOps() { return {NodeOp( - Select(graph()->allocator(), val_, child(0)->val(), axis_, indeces_))}; + Select(graph()->allocator(), val_, child(0)->val(), axis_, indices_))}; } NodeOps backwardOps() { return {NodeOp( - Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indeces_))}; + Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indices_))}; } Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) { @@ -716,7 +771,7 @@ struct SelectNodeOp : public UnaryNodeOp { if(!hash_) { size_t seed = NaryNodeOp::hash(); boost::hash_combine(seed, axis_); - for(auto i : indeces_) + for(auto i : indices_) boost::hash_combine(seed, i); hash_ = seed; } @@ -731,12 +786,12 @@ struct SelectNodeOp : public UnaryNodeOp { return false; if(axis_ != cnode->axis_) return false; - if(indeces_ != cnode->indeces_) + if(indices_ != cnode->indices_) return false; return true; } - std::vector<size_t> indeces_; + std::vector<size_t> indices_; int axis_{0}; }; |