diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-02-13 20:35:36 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-02-13 20:35:36 +0300 |
commit | a562a5dcc7f8c8ba72f2848f364e9965b61188a3 (patch) | |
tree | da39f56d37dbb8420dabb8636220ae0b6738ff2d /src/graph/node_operators_unary.h | |
parent | 7a095eadcd68d2a83095203ce4af709a14a8f9e4 (diff) |
bn
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 42 |
1 files changed, 35 insertions, 7 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 5a6da1b4..80ee0801 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -423,19 +423,17 @@ struct ExpNodeOp : public UnaryNodeOp { }; -struct PowNodeOp : public UnaryNodeOp { - float exponent_; +struct SqrtNodeOp : public UnaryNodeOp { float epsilon_; template <typename ...Args> - PowNodeOp(Expr a, float exponent, float epsilon, Args ...args) + SqrtNodeOp(Expr a, float epsilon, Args ...args) : UnaryNodeOp(a, args...), - exponent_(exponent), epsilon_(epsilon) { } NodeOps forwardOps() { return { - NodeOp(Element(_1 = Pow(epsilon_ + _2, exponent_), + NodeOp(Element(_1 = Sqrt(_2 + epsilon_), val_, children_[0]->val())) }; @@ -443,7 +441,37 @@ struct PowNodeOp : public UnaryNodeOp { NodeOps backwardOps() { return { - NodeOp(Add(exponent_ * Pow(epsilon_ + _1, exponent_ - 1.f) * _2, + NodeOp(Add(0.5f * (1.f / _1) * _2, + children_[0]->grad(), + val_, + adj_)) + }; + } + + const std::string type() { + return "sqrt"; + } + +}; + +struct SquareNodeOp : public UnaryNodeOp { + float epsilon_; + + template <typename ...Args> + SquareNodeOp(Args ...args) + : UnaryNodeOp(args...) { } + + NodeOps forwardOps() { + return { + NodeOp(Element(_1 = _2 * _2, + val_, + children_[0]->val())) + }; + } + + NodeOps backwardOps() { + return { + NodeOp(Add(2.f * _1 * _2, children_[0]->grad(), children_[0]->val(), adj_)) @@ -451,7 +479,7 @@ struct PowNodeOp : public UnaryNodeOp { } const std::string type() { - return "pow"; + return "square"; } }; |