github.com/marian-nmt/marian.git
author    Tomasz Dwojak <t.dwojak@amu.edu.pl>  2017-11-23 14:26:38 +0300
committer Tomasz Dwojak <t.dwojak@amu.edu.pl>  2017-11-23 14:26:38 +0300
commit    537fccc3e81831fb0aa8baaaaef2858c96c346ec (patch)
tree      57a48f2e1ff7894289587255f503f3ab7744203f /src/graph/node_operators_unary.h
parent    15947c6061c009679b0a00ca873d768c1159cd6f (diff)
parent    9023667939b0fdd645f971cdeb0ab4e764b07057 (diff)

Merge branch 'master' of https://github.com/marian-nmt/marian-dev into charS2S

Diffstat (limited to 'src/graph/node_operators_unary.h'):
 src/graph/node_operators_unary.h | 95
 1 file changed, 75 insertions(+), 20 deletions(-)

diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index faf21dee..8390a0c2 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -39,6 +39,25 @@ public:
}
const std::string type() { return "scalar_add"; }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, scalar_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ScalarAddNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
};
struct ScalarMultNodeOp : public UnaryNodeOp {
@@ -61,6 +80,25 @@ public:
}
  const std::string type() { return "scalar_mult"; }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, scalar_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ScalarMultNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
};
struct LogitNodeOp : public UnaryNodeOp {
@@ -256,6 +294,25 @@ struct PReLUNodeOp : public UnaryNodeOp {
const std::string type() { return "PReLU"; }
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, alpha_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<PReLUNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(alpha_ != cnode->alpha_)
+ return false;
+ return true;
+ }
+
private:
float alpha_{0.01};
};
@@ -546,8 +603,6 @@ struct SqrtNodeOp : public UnaryNodeOp {
};
struct SquareNodeOp : public UnaryNodeOp {
- float epsilon_;
-
template <typename... Args>
SquareNodeOp(Args... args) : UnaryNodeOp(args...) {}
@@ -586,16 +641,16 @@ struct RowsNodeOp : public UnaryNodeOp {
template <typename... Args>
RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
: UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
- indeces_(indeces) {}
+ indices_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
- return {NodeOp(CopyRows(val_, child(0)->val(), indeces_))};
+ return {NodeOp(CopyRows(val_, child(0)->val(), indices_))};
}
NodeOps backwardOps() {
- return {NodeOp(PasteRows(child(0)->grad(), adj_, indeces_))};
+ return {NodeOp(PasteRows(child(0)->grad(), adj_, indices_))};
}
template <class... Args>
@@ -614,7 +669,7 @@ struct RowsNodeOp : public UnaryNodeOp {
virtual size_t hash() {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
- for(auto i : indeces_)
+ for(auto i : indices_)
boost::hash_combine(seed, i);
hash_ = seed;
}
@@ -627,28 +682,28 @@ struct RowsNodeOp : public UnaryNodeOp {
Ptr<RowsNodeOp> cnode = std::dynamic_pointer_cast<RowsNodeOp>(node);
if(!cnode)
return false;
- if(indeces_ != cnode->indeces_)
+ if(indices_ != cnode->indices_)
return false;
return true;
}
- std::vector<size_t> indeces_;
+ std::vector<size_t> indices_;
};
struct ColsNodeOp : public UnaryNodeOp {
template <typename... Args>
ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
: UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
- indeces_(indeces) {}
+ indices_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
- return {NodeOp(CopyCols(val_, child(0)->val(), indeces_))};
+ return {NodeOp(CopyCols(val_, child(0)->val(), indices_))};
}
NodeOps backwardOps() {
- return {NodeOp(PasteCols(child(0)->grad(), adj_, indeces_))};
+ return {NodeOp(PasteCols(child(0)->grad(), adj_, indices_))};
}
template <class... Args>
@@ -665,7 +720,7 @@ struct ColsNodeOp : public UnaryNodeOp {
virtual size_t hash() {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
- for(auto i : indeces_)
+ for(auto i : indices_)
boost::hash_combine(seed, i);
hash_ = seed;
}
@@ -678,27 +733,27 @@ struct ColsNodeOp : public UnaryNodeOp {
Ptr<ColsNodeOp> cnode = std::dynamic_pointer_cast<ColsNodeOp>(node);
if(!cnode)
return false;
- if(indeces_ != cnode->indeces_)
+ if(indices_ != cnode->indices_)
return false;
return true;
}
- std::vector<size_t> indeces_;
+ std::vector<size_t> indices_;
};
struct SelectNodeOp : public UnaryNodeOp {
SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
: UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)),
- indeces_(indeces) {}
+ indices_(indeces) {}
NodeOps forwardOps() {
return {NodeOp(
- Select(graph()->allocator(), val_, child(0)->val(), axis_, indeces_))};
+ Select(graph()->allocator(), val_, child(0)->val(), axis_, indices_))};
}
NodeOps backwardOps() {
return {NodeOp(
- Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indeces_))};
+ Insert(graph()->allocator(), child(0)->grad(), adj_, axis_, indices_))};
}
Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
@@ -716,7 +771,7 @@ struct SelectNodeOp : public UnaryNodeOp {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
boost::hash_combine(seed, axis_);
- for(auto i : indeces_)
+ for(auto i : indices_)
boost::hash_combine(seed, i);
hash_ = seed;
}
@@ -731,12 +786,12 @@ struct SelectNodeOp : public UnaryNodeOp {
return false;
if(axis_ != cnode->axis_)
return false;
- if(indeces_ != cnode->indeces_)
+ if(indices_ != cnode->indices_)
return false;
return true;
}
- std::vector<size_t> indeces_;
+ std::vector<size_t> indices_;
int axis_{0};
};
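
The hash() and equal() overrides added throughout this diff all follow one pattern: fold any operator-specific state (a scalar, an alpha, an index vector) into the lazily cached hash with boost::hash_combine, and have equal() confirm both the dynamic type and that state, so identical subexpressions can be detected and reused. Below is a minimal self-contained sketch of that pattern; the names NodeBase, ScaleNode and Expr are illustrative only and are not Marian's actual API.

#include <boost/functional/hash.hpp>
#include <memory>
#include <string>

// Sketch of the hash()/equal() pattern from the diff above.
// NodeBase, ScaleNode and Expr are illustrative names, not Marian's API.
struct NodeBase;
using Expr = std::shared_ptr<NodeBase>;

struct NodeBase {
  virtual ~NodeBase() = default;
  virtual const std::string type() = 0;

  // Lazily computed, cached hash; derived classes mix their own state in.
  virtual size_t hash() {
    if(!hash_)
      hash_ = boost::hash<std::string>()(type());
    return hash_;
  }

  // Base equality only checks the operator type.
  virtual bool equal(Expr node) { return node && node->type() == type(); }

  size_t hash_{0};
};

// An operator carrying extra state folds that state into hash() and
// compares it in equal(), mirroring ScalarAddNodeOp / ScalarMultNodeOp /
// PReLUNodeOp above.
struct ScaleNode : public NodeBase {
  explicit ScaleNode(float s) : scalar_(s) {}
  const std::string type() override { return "scale"; }

  size_t hash() override {
    if(!hash_) {
      hash_ = NodeBase::hash();
      boost::hash_combine(hash_, scalar_);  // mix node-specific state in
    }
    return hash_;
  }

  bool equal(Expr node) override {
    if(!NodeBase::equal(node))
      return false;
    auto cnode = std::dynamic_pointer_cast<ScaleNode>(node);
    if(!cnode)
      return false;
    return scalar_ == cnode->scalar_;  // matching type alone is not enough
  }

  float scalar_;
};

Two nodes then compare equal only when they are the same operator type with the same state, which is exactly the behaviour the added overrides give ScalarAddNodeOp, ScalarMultNodeOp, PReLUNodeOp, RowsNodeOp, ColsNodeOp and SelectNodeOp in this commit.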