diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-05 23:24:47 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-05 23:24:47 +0300 |
commit | a4523f4cb3695fb2f1eb8c439258787fe2b82432 (patch) | |
tree | a2b19274c62abecd0ddf0627d188c60546a646d7 /src/graph/node_operators_unary.h | |
parent | cb54a2573368aa4e56f4bf677a64681c4a16cbac (diff) |
towards beam search with transformer model
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 609f93fc..bcc05d15 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -605,6 +605,57 @@ struct ColsNodeOp : public UnaryNodeOp { std::vector<size_t> indeces_; }; +struct SelectNodeOp : public UnaryNodeOp { + SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces) + : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)), + indeces_(indeces), axis_(axis) {} + + NodeOps forwardOps() { + return {NodeOp(Select(val_, child(0)->val(), axis_, indeces_))}; + } + + NodeOps backwardOps() { + return {NodeOp(Insert(child(0)->grad(), adj_, axis_, indeces_))}; + } + + Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) { + Shape shape = a->shape(); + shape.set(axis, indeces.size()); + return shape; + } + + const std::string type() { return "select"; } + + const std::string color() { return "orange"; } + + virtual size_t hash() { + if(!hash_) { + size_t seed = NaryNodeOp::hash(); + boost::hash_combine(seed, axis_); + for(auto i : indeces_) + boost::hash_combine(seed, i); + hash_ = seed; + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + Ptr<SelectNodeOp> cnode = std::dynamic_pointer_cast<SelectNodeOp>(node); + if(!cnode) + return false; + if(axis_ != cnode->axis_) + return false; + if(indeces_ != cnode->indeces_) + return false; + return true; + } + + std::vector<size_t> indeces_; + int axis_{0}; +}; + struct TransposeNodeOp : public UnaryNodeOp { template <typename... Args> TransposeNodeOp(Expr a, Args... args) |