From 505353d76dc95fa3df703b4ff908fa489eb0e85b Mon Sep 17 00:00:00 2001
From: Roman Grundkiewicz
Date: Sun, 29 Oct 2017 14:06:53 +0000
Subject: Add comments to ReLU and Swish

---
 src/graph/node_operators_unary.h | 27 ++++++++++++++++++---------
 1 file changed, 18 insertions(+), 9 deletions(-)

(limited to 'src/graph/node_operators_unary.h')

diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 38241258..6a78da67 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -199,9 +199,8 @@ struct TanhNodeOp : public NaryNodeOp {
  * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
  * linear</a> node in an expression graph.
  *
- * This node implements the activationfunction \f$ f(x) = \max(0, x) \f$ and
+ * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and
  * its derivative:
- *
  * \f[
  *    f^\prime(x) =
  *    \begin{cases}
@@ -215,12 +214,20 @@ struct ReLUNodeOp : public UnaryNodeOp {
   ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}

   NodeOps forwardOps() {
-    return {NodeOp(Element(_1 = ReLU(_2), val_, child(0)->val()))};
+    // f(x) = max(0, x)
+    return {NodeOp(Element(_1 = ReLU(_2),
+                           val_,            // _1 := f(x) to be calculated
+                           child(0)->val()  // _2 := x
+                           ))};
   }

   NodeOps backwardOps() {
-    return {NodeOp(
-        Add(_1 * ReLUback(_2), child(0)->grad(), adj_, child(0)->val()))};
+    // dJ/dx += dJ/df * binarystep(x)
+    return {NodeOp(Add(_1 * ReLUback(_2),
+                       child(0)->grad(), // dJ/dx
+                       adj_,             // _1 := dJ/df
+                       child(0)->val()   // _2 := x
+                       ))};
   }

   const std::string type() { return "ReLU"; }
@@ -245,11 +252,13 @@ struct SwishNodeOp : public UnaryNodeOp {
   }

   NodeOps backwardOps() {
+    // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
     return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)),
-                       child(0)->grad(),
-                       adj_,
-                       child(0)->val(),
-                       val_))};
+                       child(0)->grad(), // dJ/dx
+                       adj_,             // _1 := dJ/df
+                       child(0)->val(),  // _2 := x
+                       val_              // _3 := f(x) = x*sigma(x)
+                       ))};
   }

   const std::string type() { return "swish"; }
--
cgit v1.2.3
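
For reference, the identity used in the new SwishNodeOp backward comment,
dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) ), follows from the product
rule and \f$ \sigma^\prime(x) = \sigma(x)(1 - \sigma(x)) \f$; this worked
check is illustrative, not part of the commit:

\f[
f(x) = x\,\sigma(x), \qquad
f^\prime(x) = \sigma(x) + x\,\sigma(x)(1 - \sigma(x))
            = x\,\sigma(x) + \sigma(x)(1 - x\,\sigma(x))
            = f(x) + \sigma(x)(1 - f(x)).
\f]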
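
A minimal standalone C++ sketch that checks both documented gradients against
central finite differences; the helpers (sigma, relu_grad, swish_grad) are
illustrative stand-ins, not Marian API:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <initializer_list>

// sigma(x) = 1 / (1 + exp(-x))
static float sigma(float x) { return 1.f / (1.f + std::exp(-x)); }

// ReLU: f(x) = max(0, x), f'(x) = binarystep(x)
static float relu(float x) { return std::max(0.f, x); }
static float relu_grad(float x) { return x > 0.f ? 1.f : 0.f; }

// Swish: f(x) = x * sigma(x), f'(x) = f(x) + sigma(x) * (1 - f(x))
static float swish(float x) { return x * sigma(x); }
static float swish_grad(float x) {
  const float f = swish(x);
  return f + sigma(x) * (1.f - f);
}

int main() {
  const float eps = 1e-3f;
  for (float x : {-2.f, -0.5f, 0.5f, 2.f}) {  // skip x = 0, the ReLU kink
    // central difference: (f(x + eps) - f(x - eps)) / (2 * eps)
    const float relu_fd = (relu(x + eps) - relu(x - eps)) / (2.f * eps);
    const float swish_fd = (swish(x + eps) - swish(x - eps)) / (2.f * eps);
    std::printf("x=%5.2f  relu %f vs %f  swish %f vs %f\n",
                x, relu_grad(x), relu_fd, swish_grad(x), swish_fd);
    assert(std::fabs(relu_grad(x) - relu_fd) < 1e-2f);
    assert(std::fabs(swish_grad(x) - swish_fd) < 1e-2f);
  }
  return 0;
}

Away from the kink at x = 0, where ReLUback is the subgradient binarystep(x),
both assertions hold well within the finite-difference tolerance.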