Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Marcin Junczys-Dowmunt <junczys@amu.edu.pl> 2017-11-02 22:37:43 +0300
committer: Marcin Junczys-Dowmunt <junczys@amu.edu.pl> 2017-11-02 22:37:43 +0300
commit: ca64c429e4aa4dcd49b68bab6a7d744fe06b44c2 (patch)
tree: d6469621854e3a44ec43527ac843d67d9eb1f586 /src/graph/node_operators_unary.h
parent: cd6fca847dd9aa0b5236306beaaa8f778e2d8726 (diff)
new functional templates
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- src/graph/node_operators_unary.h 59
1 files changed, 43 insertions, 16 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index fcf1c290..05294bee 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -4,8 +4,8 @@
#include "graph/node.h"
#include "kernels/sparse.h"
#include "kernels/tensor_operators.h"
-#include "kernels/thrust_functions.h"
#include "tensors/tensor.h"
+#include "gpu/functions.h"
#ifdef CUDNN
@@ -51,6 +51,7 @@ public:
: UnaryNodeOp(a, args...), scalar_{scalar} {}
NodeOps forwardOps() {
+ using namespace functional;
return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))};
}
@@ -69,10 +70,12 @@ public:
: UnaryNodeOp(a, args...), scalar_{scalar} {}
NodeOps forwardOps() {
+ using namespace functional;
return {NodeOp(Element(_1 = scalar_ * _2, val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))};
}
@@ -84,10 +87,12 @@ struct LogitNodeOp : public UnaryNodeOp {
LogitNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = Sigma(_2), val_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Element(_1 = logit(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {NodeOp(Add(_1 * _2 * (1.0f - _2), child(0)->grad(), adj_, val_))};
}
@@ -145,13 +150,14 @@ struct TanhNodeOp : public NaryNodeOp {
}
NodeOps forwardOps() {
+ using namespace functional;
switch(children_.size()) {
- case 1: return {NodeOp(Element(_1 = Tanh(_2), val_, child(0)->val()))};
+ case 1: return {NodeOp(Element(_1 = tanh(_2), val_, child(0)->val()))};
case 2:
return {NodeOp(Element(
- _1 = Tanh(_2 + _3), val_, child(0)->val(), child(1)->val()))};
+ _1 = tanh(_2 + _3), val_, child(0)->val(), child(1)->val()))};
case 3:
- return {NodeOp(Element(_1 = Tanh(_2 + _3 + _4),
+ return {NodeOp(Element(_1 = tanh(_2 + _3 + _4),
val_,
child(0)->val(),
child(1)->val(),
@@ -164,13 +170,14 @@ struct TanhNodeOp : public NaryNodeOp {
child(1)->val(),
child(2)->val());
for(int i = 3; i < children_.size(); ++i)
- Element(_1 += _2, val_, child(i)->val());
- Element(_1 = Tanh(_1), val_);)
+ Element(_1 = _1 + _2, val_, child(i)->val());
+ Element(_1 = tanh(_1), val_);)
};
}
}
NodeOps backwardOps() {
+ using namespace functional;
NodeOps ops;
for(int i = 0; i < children_.size(); i++) {
ops.push_back(
@@ -205,6 +212,7 @@ struct ReLUNodeOp : public UnaryNodeOp {
NodeOps forwardOps() {
// f(x) = max(0, x)
+ using namespace functional;
return {NodeOp(Element(_1 = ReLU(_2),
val_, // _1 := f(x) to be calculated
child(0)->val() // _2 := x
@@ -212,6 +220,7 @@ struct ReLUNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
+ using namespace functional;
// dJ/dx += dJ/df * binarystep(x)
return {NodeOp(Add(_1 * ReLUback(_2),
child(0)->grad(), // dJ/dx
@@ -254,10 +263,12 @@ struct PReLUNodeOp : public UnaryNodeOp {
: UnaryNodeOp(args...), alpha_(alpha) {}
NodeOps forwardOps() {
+ using namespace functional;
return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {NodeOp(Add(
_1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))};
}
@@ -283,12 +294,14 @@ struct SwishNodeOp : public UnaryNodeOp {
SwishNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = _2 * Sigma(_2), val_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Element(_1 = _2 * logit(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
// dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
- return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)),
+ return {NodeOp(Add(_1 * (_3 + logit(_2) * (1.f - _3)),
child(0)->grad(), // dJ/dx
adj_, // _1 := dJ/df
child(0)->val(), // _2 := x
@@ -424,6 +437,7 @@ struct MeanNodeOp : public UnaryNodeOp {
: UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
NodeOps forwardOps() {
+ using namespace functional;
int left = child(0)->shape().elements() / val_->shape().elements();
float scale = 1.f / left;
@@ -431,6 +445,7 @@ struct MeanNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
+ using namespace functional;
int left = child(0)->shape().elements() / val_->shape().elements();
float scale = 1.f / left;
@@ -474,10 +489,12 @@ struct LogNodeOp : public UnaryNodeOp {
LogNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = Log(_2), val_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Element(_1 = log(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {
NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))};
}
@@ -490,11 +507,13 @@ struct ExpNodeOp : public UnaryNodeOp {
ExpNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = Exp(_2), val_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Element(_1 = exp(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
- return {NodeOp(Add(_1 * Exp(_2), child(0)->grad(), adj_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Add(_1 * exp(_2), child(0)->grad(), adj_, child(0)->val()))};
}
const std::string type() { return "exp"; }
@@ -508,10 +527,12 @@ struct SqrtNodeOp : public UnaryNodeOp {
: UnaryNodeOp(a, args...), epsilon_(epsilon) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = Sqrt(_2 + epsilon_), val_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Element(_1 = sqrt(_2 + epsilon_), val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {NodeOp(Add(0.5f * (1.f / _1) * _2, child(0)->grad(), val_, adj_))};
}
@@ -545,10 +566,12 @@ struct SquareNodeOp : public UnaryNodeOp {
SquareNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
+ using namespace functional;
return {NodeOp(Element(_1 = _2 * _2, val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {
NodeOp(Add(2.f * _1 * _2, child(0)->grad(), child(0)->val(), adj_))};
}
@@ -561,10 +584,14 @@ struct NegNodeOp : public UnaryNodeOp {
NegNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
+ using namespace functional;
return {NodeOp(Element(_1 = -_2, val_, child(0)->val()))};
}
- NodeOps backwardOps() { return {NodeOp(Add(-_1, child(0)->grad(), adj_))}; }
+ NodeOps backwardOps() {
+ using namespace functional;
+ return {NodeOp(Add(-_1, child(0)->grad(), adj_))};
+ }
const std::string type() { return "-"; }
};
@@ -974,13 +1001,13 @@ struct ShiftNodeOp : public UnaryNodeOp {
// void forward() {
// sparse::LfaForward(val_, child(0)->val(), child(1)->val(), lf_);
// // val = x + ln(p + eps)
-// Element(_1 = (Log(_1 + eps_) + _2), val_, child(0)->val());
+// Element(_1 = (log(_1 + eps_) + _2), val_, child(0)->val());
// }
//
// void backward() {
// Add(_1, child(0)->grad(), adj_);
// // adj' = adj / (p + eps) = adj / exp(val - x)
-// Element(_1 = _1 / Exp(_2 - _3), adj_, val_, child(0)->val());
+// Element(_1 = _1 / exp(_2 - _3), adj_, val_, child(0)->val());
// sparse::LfaBackward(child(1)->grad(), adj_, lf_);
// }
//