diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-11-02 22:37:43 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-11-02 22:37:43 +0300 |
commit | ca64c429e4aa4dcd49b68bab6a7d744fe06b44c2 (patch) | |
tree | d6469621854e3a44ec43527ac843d67d9eb1f586 /src/graph/node_operators_unary.h | |
parent | cd6fca847dd9aa0b5236306beaaa8f778e2d8726 (diff) |
new functional templates
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 59 |
1 files changed, 43 insertions, 16 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index fcf1c290..05294bee 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -4,8 +4,8 @@ #include "graph/node.h" #include "kernels/sparse.h" #include "kernels/tensor_operators.h" -#include "kernels/thrust_functions.h" #include "tensors/tensor.h" +#include "gpu/functions.h" #ifdef CUDNN @@ -51,6 +51,7 @@ public: : UnaryNodeOp(a, args...), scalar_{scalar} {} NodeOps forwardOps() { + using namespace functional; return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))}; } @@ -69,10 +70,12 @@ public: : UnaryNodeOp(a, args...), scalar_{scalar} {} NodeOps forwardOps() { + using namespace functional; return {NodeOp(Element(_1 = scalar_ * _2, val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))}; } @@ -84,10 +87,12 @@ struct LogitNodeOp : public UnaryNodeOp { LogitNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { - return {NodeOp(Element(_1 = Sigma(_2), val_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Element(_1 = logit(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return {NodeOp(Add(_1 * _2 * (1.0f - _2), child(0)->grad(), adj_, val_))}; } @@ -145,13 +150,14 @@ struct TanhNodeOp : public NaryNodeOp { } NodeOps forwardOps() { + using namespace functional; switch(children_.size()) { - case 1: return {NodeOp(Element(_1 = Tanh(_2), val_, child(0)->val()))}; + case 1: return {NodeOp(Element(_1 = tanh(_2), val_, child(0)->val()))}; case 2: return {NodeOp(Element( - _1 = Tanh(_2 + _3), val_, child(0)->val(), child(1)->val()))}; + _1 = tanh(_2 + _3), val_, child(0)->val(), child(1)->val()))}; case 3: - return {NodeOp(Element(_1 = Tanh(_2 + _3 + _4), + return {NodeOp(Element(_1 = tanh(_2 + _3 + _4), val_, child(0)->val(), child(1)->val(), @@ -164,13 +170,14 @@ struct TanhNodeOp : public NaryNodeOp { child(1)->val(), child(2)->val()); for(int i = 3; i < children_.size(); ++i) - Element(_1 += _2, val_, child(i)->val()); - Element(_1 = Tanh(_1), val_);) + Element(_1 = _1 + _2, val_, child(i)->val()); + Element(_1 = tanh(_1), val_);) }; } } NodeOps backwardOps() { + using namespace functional; NodeOps ops; for(int i = 0; i < children_.size(); i++) { ops.push_back( @@ -205,6 +212,7 @@ struct ReLUNodeOp : public UnaryNodeOp { NodeOps forwardOps() { // f(x) = max(0, x) + using namespace functional; return {NodeOp(Element(_1 = ReLU(_2), val_, // _1 := f(x) to be calculated child(0)->val() // _2 := x @@ -212,6 +220,7 @@ struct ReLUNodeOp : public UnaryNodeOp { } NodeOps backwardOps() { + using namespace functional; // dJ/dx += dJ/df * binarystep(x) return {NodeOp(Add(_1 * ReLUback(_2), child(0)->grad(), // dJ/dx @@ -254,10 +263,12 @@ struct PReLUNodeOp : public UnaryNodeOp { : UnaryNodeOp(args...), alpha_(alpha) {} NodeOps forwardOps() { + using namespace functional; return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return {NodeOp(Add( _1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))}; } @@ -283,12 +294,14 @@ struct SwishNodeOp : public UnaryNodeOp { SwishNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { - return {NodeOp(Element(_1 = _2 * Sigma(_2), val_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Element(_1 = _2 * logit(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) ) - return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)), + return {NodeOp(Add(_1 * (_3 + logit(_2) * (1.f - _3)), child(0)->grad(), // dJ/dx adj_, // _1 := dJ/df child(0)->val(), // _2 := x @@ -424,6 +437,7 @@ struct MeanNodeOp : public UnaryNodeOp { : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} NodeOps forwardOps() { + using namespace functional; int left = child(0)->shape().elements() / val_->shape().elements(); float scale = 1.f / left; @@ -431,6 +445,7 @@ struct MeanNodeOp : public UnaryNodeOp { } NodeOps backwardOps() { + using namespace functional; int left = child(0)->shape().elements() / val_->shape().elements(); float scale = 1.f / left; @@ -474,10 +489,12 @@ struct LogNodeOp : public UnaryNodeOp { LogNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { - return {NodeOp(Element(_1 = Log(_2), val_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Element(_1 = log(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return { NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))}; } @@ -490,11 +507,13 @@ struct ExpNodeOp : public UnaryNodeOp { ExpNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { - return {NodeOp(Element(_1 = Exp(_2), val_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Element(_1 = exp(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { - return {NodeOp(Add(_1 * Exp(_2), child(0)->grad(), adj_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Add(_1 * exp(_2), child(0)->grad(), adj_, child(0)->val()))}; } const std::string type() { return "exp"; } @@ -508,10 +527,12 @@ struct SqrtNodeOp : public UnaryNodeOp { : UnaryNodeOp(a, args...), epsilon_(epsilon) {} NodeOps forwardOps() { - return {NodeOp(Element(_1 = Sqrt(_2 + epsilon_), val_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Element(_1 = sqrt(_2 + epsilon_), val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return {NodeOp(Add(0.5f * (1.f / _1) * _2, child(0)->grad(), val_, adj_))}; } @@ -545,10 +566,12 @@ struct SquareNodeOp : public UnaryNodeOp { SquareNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { + using namespace functional; return {NodeOp(Element(_1 = _2 * _2, val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return { NodeOp(Add(2.f * _1 * _2, child(0)->grad(), child(0)->val(), adj_))}; } @@ -561,10 +584,14 @@ struct NegNodeOp : public UnaryNodeOp { NegNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { + using namespace functional; return {NodeOp(Element(_1 = -_2, val_, child(0)->val()))}; } - NodeOps backwardOps() { return {NodeOp(Add(-_1, child(0)->grad(), adj_))}; } + NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(-_1, child(0)->grad(), adj_))}; + } const std::string type() { return "-"; } }; @@ -974,13 +1001,13 @@ struct ShiftNodeOp : public UnaryNodeOp { // void forward() { // sparse::LfaForward(val_, child(0)->val(), child(1)->val(), lf_); // // val = x + ln(p + eps) -// Element(_1 = (Log(_1 + eps_) + _2), val_, child(0)->val()); +// Element(_1 = (log(_1 + eps_) + _2), val_, child(0)->val()); // } // // void backward() { // Add(_1, child(0)->grad(), adj_); // // adj' = adj / (p + eps) = adj / exp(val - x) -// Element(_1 = _1 / Exp(_2 - _3), adj_, val_, child(0)->val()); +// Element(_1 = _1 / exp(_2 - _3), adj_, val_, child(0)->val()); // sparse::LfaBackward(child(1)->grad(), adj_, lf_); // } // |