Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-08 00:34:39 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-08 00:34:39 +0300
commit7bb558ecfcdfef5c629f5a9d85ea2b4680bb60aa (patch)
tree0664fc883448c68ad81b91e4b76bef319d45b638 /src/graph/node_operators_unary.h
parent68d61a662294cb3f26b3935da95a8ce1c404c293 (diff)
parentac21830517e75e31a0bca3b071292acff0d9610d (diff)
Merge branch 'master' into jonathac/windows_build
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h39
1 files changed, 39 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index e14f6546..259e6072 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -99,6 +99,45 @@ public:
}
};
+struct ClipNodeOp : public UnaryNodeOp {
+private:
+ float clip_{0};
+
+public:
+ ClipNodeOp(Expr a, float clip) : UnaryNodeOp(a), clip_{clip} {}
+
+ NodeOps forwardOps() {
+ using namespace functional;
+ return {NodeOp(Element(_1 = clip(_2, clip_), val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+ using namespace functional;
+ return {NodeOp(Add(bump(_1, clip_) * _2, child(0)->grad(), child(0)->val(), adj_))};
+ }
+
+ const std::string type() { return "clip"; }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, clip_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ClipNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(clip_ != cnode->clip_)
+ return false;
+ return true;
+ }
+};
+
struct LogitNodeOp : public UnaryNodeOp {
LogitNodeOp(Expr a) : UnaryNodeOp(a) {}