Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-10-05 23:24:47 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-10-05 23:24:47 +0300
commita4523f4cb3695fb2f1eb8c439258787fe2b82432 (patch)
treea2b19274c62abecd0ddf0627d188c60546a646d7 /src/graph/node_operators_unary.h
parentcb54a2573368aa4e56f4bf677a64681c4a16cbac (diff)
towards beam search with transformer model
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h51
1 files changed, 51 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 609f93fc..bcc05d15 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -605,6 +605,57 @@ struct ColsNodeOp : public UnaryNodeOp {
std::vector<size_t> indeces_;
};
+struct SelectNodeOp : public UnaryNodeOp {
+ SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
+ : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)),
+ indeces_(indeces), axis_(axis) {}
+
+ NodeOps forwardOps() {
+ return {NodeOp(Select(val_, child(0)->val(), axis_, indeces_))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(Insert(child(0)->grad(), adj_, axis_, indeces_))};
+ }
+
+ Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
+ Shape shape = a->shape();
+ shape.set(axis, indeces.size());
+ return shape;
+ }
+
+ const std::string type() { return "select"; }
+
+ const std::string color() { return "orange"; }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ size_t seed = NaryNodeOp::hash();
+ boost::hash_combine(seed, axis_);
+ for(auto i : indeces_)
+ boost::hash_combine(seed, i);
+ hash_ = seed;
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ Ptr<SelectNodeOp> cnode = std::dynamic_pointer_cast<SelectNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(axis_ != cnode->axis_)
+ return false;
+ if(indeces_ != cnode->indeces_)
+ return false;
+ return true;
+ }
+
+ std::vector<size_t> indeces_;
+ int axis_{0};
+};
+
struct TransposeNodeOp : public UnaryNodeOp {
template <typename... Args>
TransposeNodeOp(Expr a, Args... args)