Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> 2017-10-29 17:06:53 +0300
committer: Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> 2017-10-29 20:26:35 +0300
commit: 505353d76dc95fa3df703b4ff908fa489eb0e85b (patch)
tree: 22d29480d51fb09ec3c0acd0e00545743784a764 /src/graph/node_operators_unary.h
parent: 3d226d9bd526694ffecb39c6ba15cd2748ccf7ab (diff)
Add comments to ReLU and Swish
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--  src/graph/node_operators_unary.h  27
1 file changed, 18 insertions, 9 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 38241258..6a78da67 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -199,9 +199,8 @@ struct TanhNodeOp : public NaryNodeOp {
* href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
* linear</a> node in an expression graph.
*
- * This node implements the activationfunction \f$ f(x) = \max(0, x) \f$ and
+ * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and
* its derivative:
- *
* \f[
* f^\prime(x) =
* \begin{cases}
@@ -215,12 +214,20 @@ struct ReLUNodeOp : public UnaryNodeOp {
ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = ReLU(_2), val_, child(0)->val()))};
+ // f(x) = max(0, x)
+ return {NodeOp(Element(_1 = ReLU(_2),
+ val_, // _1 := f(x) to be calculated
+ child(0)->val() // _2 := x
+ ))};
}
NodeOps backwardOps() {
- return {NodeOp(
- Add(_1 * ReLUback(_2), child(0)->grad(), adj_, child(0)->val()))};
+ // dJ/dx += dJ/df * binarystep(x)
+ return {NodeOp(Add(_1 * ReLUback(_2),
+ child(0)->grad(), // dJ/dx
+ adj_, // _1 := dJ/df
+ child(0)->val() // _2 := f(x) = max(0, x)
+ ))};
}
const std::string type() { return "ReLU"; }
@@ -245,11 +252,13 @@ struct SwishNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
+ // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)),
- child(0)->grad(),
- adj_,
- child(0)->val(),
- val_))};
+ child(0)->grad(), // dJ/dx
+ adj_, // _1 := dJ/df
+ child(0)->val(), // _2 := x
+ val_ // _3 := f(x) = x*sigma(x)
+ ))};
}
const std::string type() { return "swish"; }