Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-04-15 11:33:29 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-04-15 11:33:29 +0300
commita124126c85571b02b49e379011e3dbfedffaf8a1 (patch)
tree17565a3c0b5d3bc90d2503c705888eceaa2595d3 /src/graph/node_operators_unary.h
parentb7a5fe5cb921b82cc18b4b9ffc6d605548b9d617 (diff)
cleaner batching, operator for subselecting columns
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h55
1 files changed, 55 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 0aa7f0a6..b223c5ad 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -555,6 +555,61 @@ struct RowsNodeOp : public UnaryNodeOp {
std::vector<size_t> indeces_;
};
+struct ColsNodeOp : public UnaryNodeOp {
+ template <typename ...Args>
+ ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args ...args)
+ : UnaryNodeOp(a, keywords::shape=newShape(a, indeces), args...),
+ indeces_(indeces) {
+ }
+
+ NodeOps forwardOps() {
+ // @TODO: solve this with a tensor!
+
+ return {
+ NodeOp(CopyCols(val_,
+ children_[0]->val(),
+ indeces_))
+ };
+ }
+
+ NodeOps backwardOps() {
+ return {
+ NodeOp(PasteCols(children_[0]->grad(),
+ adj_,
+ indeces_))
+ };
+ }
+
+ template <class ...Args>
+ Shape newShape(Expr a, const std::vector<size_t>& indeces) {
+ Shape shape = a->shape();
+ shape.set(1, indeces.size());
+ return shape;
+ }
+
+ const std::string type() {
+ return "cols";
+ }
+
+ const std::string color() {
+ return "orange";
+ }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ size_t seed = NaryNodeOp::hash();
+ for(auto i : indeces_)
+ boost::hash_combine(seed, i);
+ hash_ = seed;
+ }
+ return hash_;
+ }
+
+
+ std::vector<size_t> indeces_;
+};
+
+
struct TransposeNodeOp : public UnaryNodeOp {
template <typename ...Args>
TransposeNodeOp(Expr a, Args ...args)