From 6571741b735e1fed9f999476361f0d68de6f1118 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Fri, 11 May 2018 22:03:52 -0700 Subject: clipping gemm --- src/graph/node_operators_unary.h | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) (limited to 'src/graph/node_operators_unary.h') diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 273adf44..8f1fd2f7 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -99,6 +99,46 @@ public: } }; +struct ClipNodeOp : public UnaryNodeOp { +private: + float clip_{0}; + +public: + ClipNodeOp(Expr a, float clip) : UnaryNodeOp(a), clip_{clip} {} + + NodeOps forwardOps() { + using namespace functional; + return {NodeOp(Element(_1 = clip(_2, clip_), val_, child(0)->val()))}; + } + + NodeOps backwardOps() { + using namespace functional; + // @TODO: is this correct? + return {NodeOp(Add(_1, child(0)->grad(), adj_))}; + } + + const std::string type() { return "clip"; } + + virtual size_t hash() { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + boost::hash_combine(hash_, clip_); + } + return hash_; + } + + virtual bool equal(Expr node) { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast(node); + if(!cnode) + return false; + if(clip_ != cnode->clip_) + return false; + return true; + } +}; + struct LogitNodeOp : public UnaryNodeOp { LogitNodeOp(Expr a) : UnaryNodeOp(a) {} -- cgit v1.2.3