diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-04-15 11:33:29 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-04-15 11:33:29 +0300 |
commit | a124126c85571b02b49e379011e3dbfedffaf8a1 (patch) | |
tree | 17565a3c0b5d3bc90d2503c705888eceaa2595d3 /src/graph/node_operators_unary.h | |
parent | b7a5fe5cb921b82cc18b4b9ffc6d605548b9d617 (diff) |
cleaner batching, operator for subselecting columns
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 0aa7f0a6..b223c5ad 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -555,6 +555,61 @@ struct RowsNodeOp : public UnaryNodeOp { std::vector<size_t> indeces_; }; +struct ColsNodeOp : public UnaryNodeOp { + template <typename ...Args> + ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args ...args) + : UnaryNodeOp(a, keywords::shape=newShape(a, indeces), args...), + indeces_(indeces) { + } + + NodeOps forwardOps() { + // @TODO: solve this with a tensor! + + return { + NodeOp(CopyCols(val_, + children_[0]->val(), + indeces_)) + }; + } + + NodeOps backwardOps() { + return { + NodeOp(PasteCols(children_[0]->grad(), + adj_, + indeces_)) + }; + } + + template <class ...Args> + Shape newShape(Expr a, const std::vector<size_t>& indeces) { + Shape shape = a->shape(); + shape.set(1, indeces.size()); + return shape; + } + + const std::string type() { + return "cols"; + } + + const std::string color() { + return "orange"; + } + + virtual size_t hash() { + if(!hash_) { + size_t seed = NaryNodeOp::hash(); + for(auto i : indeces_) + boost::hash_combine(seed, i); + hash_ = seed; + } + return hash_; + } + + + std::vector<size_t> indeces_; +}; + + struct TransposeNodeOp : public UnaryNodeOp { template <typename ...Args> TransposeNodeOp(Expr a, Args ...args) |