diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-06-08 00:34:39 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-06-08 00:34:39 +0300 |
commit | 7bb558ecfcdfef5c629f5a9d85ea2b4680bb60aa (patch) | |
tree | 0664fc883448c68ad81b91e4b76bef319d45b638 /src/graph/node_operators_unary.h | |
parent | 68d61a662294cb3f26b3935da95a8ce1c404c293 (diff) | |
parent | ac21830517e75e31a0bca3b071292acff0d9610d (diff) |
Merge branch 'master' into jonathac/windows_build
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index e14f6546..259e6072 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -99,6 +99,45 @@ public: } }; +struct ClipNodeOp : public UnaryNodeOp { +private: + float clip_{0}; + +public: + ClipNodeOp(Expr a, float clip) : UnaryNodeOp(a), clip_{clip} {} + + NodeOps forwardOps() { + using namespace functional; + return {NodeOp(Element(_1 = clip(_2, clip_), val_, child(0)->val()))}; + } + + NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(bump(_1, clip_) * _2, child(0)->grad(), child(0)->val(), adj_))}; + } + + const std::string type() { return "clip"; } + + virtual size_t hash() { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + boost::hash_combine(hash_, clip_); + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast<ClipNodeOp>(node); + if(!cnode) + return false; + if(clip_ != cnode->clip_) + return false; + return true; + } +}; + struct LogitNodeOp : public UnaryNodeOp { LogitNodeOp(Expr a) : UnaryNodeOp(a) {} |