Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-09-19 03:20:28 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-09-19 03:20:28 +0300
commitaed9f48660bdf504a2b0db6ad6261f17e679c591 (patch)
treebc1196543ff98799f417521799e77f6f32f063f7 /src/graph/node_operators_unary.h
parent9f79c919dc886408d7180e207f60a2a3aa9af3d6 (diff)
more complrex transformer
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h51
1 files changed, 51 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 76b5f376..a269e501 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -598,6 +598,57 @@ struct TransposeNodeOp : public UnaryNodeOp {
const std::string color() { return "orange"; }
};
+struct Transpose4DNodeOp : public UnaryNodeOp {
+ Shape permute_;
+
+ Transpose4DNodeOp(Expr a, Shape permute)
+ : UnaryNodeOp(a, keywords::shape = newShape(a, permute)),
+ permute_{permute} {}
+
+ NodeOps forwardOps() {
+ return { NodeOp(Transpose4D(val_, child(0)->val(), permute_)) };
+ }
+
+ NodeOps backwardOps() {
+ return { NodeOp(Transpose4D(child(0)->grad(), adj_, permute_)) };
+ }
+
+ template <class... Args>
+ Shape newShape(Expr a, Shape permute) {
+ Shape shape;
+
+ for(int i = 0; i < 4; ++i)
+ shape.set(i, a->shape()[permute[i]]);
+
+ return shape;
+ }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ size_t seed = NaryNodeOp::hash();
+ for(auto s : permute_)
+ boost::hash_combine(seed, s);
+ hash_ = seed;
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ Ptr<Transpose4DNodeOp> cnode = std::dynamic_pointer_cast<Transpose4DNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(permute_ != cnode->permute_)
+ return false;
+ return true;
+ }
+
+ const std::string type() { return "transpose4d"; }
+
+ const std::string color() { return "orange"; }
+};
+
class ReshapeNodeOp : public UnaryNodeOp {
private:
Expr reshapee_;