author     Marcin Junczys-Dowmunt <junczys@amu.edu.pl>   2017-11-02 00:44:30 +0300
committer  Marcin Junczys-Dowmunt <junczys@amu.edu.pl>   2017-11-02 00:44:30 +0300
commit     289c25cb2ca0df9077480f8e843e99ea2ef8f018 (patch)
tree       89ad435c7feea3e14c6cd7245e00d0b84480e8fd /src/graph
parent     0d20bb57c19290eb01c0e1d4ad539d94571b0f80 (diff)
parent     e7bc3c7e68d075c4f14c15f5d7a46925f14d02f9 (diff)

    merge with develop
Diffstat (limited to 'src/graph')
-rw-r--r--  src/graph/expression_operators.cu  |  16
-rw-r--r--  src/graph/expression_operators.h   |   6
-rw-r--r--  src/graph/node_operators_unary.h   | 104
3 files changed, 105 insertions(+), 21 deletions(-)
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index deed4815..98570862 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -20,6 +20,14 @@ Expr relu(Expr a) {
   return Expression<ReLUNodeOp>(a);
 }
 
+Expr leakyrelu(Expr a) {
+  return Expression<PReLUNodeOp>(0.01f, a);
+}
+
+Expr prelu(Expr a, float alpha) {
+  return Expression<PReLUNodeOp>(alpha, a);
+}
+
 Expr log(Expr a) {
   return Expression<LogNodeOp>(a);
 };
@@ -238,6 +246,14 @@ Expr relu(const std::vector<Expr>&) {
   ABORT("Not implemented");
 }
 
+Expr leakyrelu(const std::vector<Expr>&) {
+  ABORT("Not implemented");
+}
+
+Expr prelu(const std::vector<Expr>&, float alpha) {
+  ABORT("Not implemented");
+}
+
 Expr sqrt(Expr a, float eps) {
   return Expression<SqrtNodeOp>(a, eps);
 }
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 8daf8ef9..f6721b52 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -24,6 +24,12 @@ Expr tanh(Args... args) {
 Expr relu(Expr a);
 Expr relu(const std::vector<Expr>&);
 
+Expr leakyrelu(Expr a);
+Expr leakyrelu(const std::vector<Expr>&);
+
+Expr prelu(Expr a, float alpha = 0.01);
+Expr prelu(const std::vector<Expr>&, float alpha = 0.01);
+
 Expr log(Expr a);
 
 Expr exp(Expr a);
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 4457ad2e..844f8905 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -186,38 +186,98 @@ struct TanhNodeOp : public NaryNodeOp {
 /**
  * Represents a <a
-href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
-linear</a> node
- * in an expression graph.
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
+ * linear</a> node in an expression graph.
  *
- * This node implements the <a
-href="https://en.wikipedia.org/wiki/Activation_function">activation function</a>
- * \f$f(x) = \max(0, x)\f$ and its derivative:
- *
-\f[
- f^\prime(x) =
-   \begin{cases}
-     0 & \text{if } x \leq 0 \\
-     1 & \text{if } x > 0
-   \end{cases}
-\f]
+ * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and
+ * its derivative:
+ * \f[
+ *   f^\prime(x) =
+ *     \begin{cases}
+ *       0 & \text{if } x \leq 0 \\
+ *       1 & \text{if } x > 0
+ *     \end{cases}
+ * \f]
  */
 struct ReLUNodeOp : public UnaryNodeOp {
   template <typename... Args>
   ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
 
   NodeOps forwardOps() {
-    return {NodeOp(Element(_1 = ReLU(_2), val_, child(0)->val()))};
+    // f(x) = max(0, x)
+    return {NodeOp(Element(_1 = ReLU(_2),
+                           val_,            // _1 := f(x) to be calculated
+                           child(0)->val()  // _2 := x
+                           ))};
   }
 
   NodeOps backwardOps() {
-    return {NodeOp(
-        Add(_1 * ReLUback(_2), child(0)->grad(), adj_, child(0)->val()))};
+    // dJ/dx += dJ/df * binarystep(x)
+    return {NodeOp(Add(_1 * ReLUback(_2),
+                       child(0)->grad(),  // dJ/dx
+                       adj_,              // _1 := dJ/df
+                       child(0)->val()    // _2 := f(x) = max(0, x)
+                       ))};
   }
 
   const std::string type() { return "ReLU"; }
 };
 
+/**
+ * Represents a <a
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">parametric
+ * rectified linear unit</a> node in an expression graph.
+ * For \f$ \alpha = 0.01 \f$ (the default value) it is equivalent to Leaky
+ * ReLU.
+ *
+ * This node implements the activation function:
+ * \f[
+ *   f(x, \alpha) =
+ *     \begin{cases}
+ *       \alpha x & \text{if } x \leq 0 \\
+ *       x        & \text{if } x > 0
+ *     \end{cases}
+ * \f]
+ *
+ * and its derivative:
+ * \f[
+ *   f^\prime(x, \alpha) =
+ *     \begin{cases}
+ *       \alpha & \text{if } x \leq 0 \\
+ *       1      & \text{if } x > 0
+ *     \end{cases}
+ * \f]
+ */
+struct PReLUNodeOp : public UnaryNodeOp {
+  template <typename... Args>
+  PReLUNodeOp(float alpha, Args... args)
+      : UnaryNodeOp(args...), alpha_(alpha) {}
+
+  NodeOps forwardOps() {
+    return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
+  }
+
+  NodeOps backwardOps() {
+    return {NodeOp(Add(
+        _1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))};
+  }
+
+  const std::string type() { return "PReLU"; }
+
+private:
+  float alpha_{0.01};
+};
+
+/**
+ * Represents a <a href="https://arxiv.org/pdf/1710.05941.pdf">swish</a> node
+ * in an expression graph.
+ *
+ * This node implements the activation function
+ * \f$ f(x) = x \cdot \sigma(x) \f$
+ * and its derivative
+ * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ .
+ *
+ */
 struct SwishNodeOp : public UnaryNodeOp {
   template <typename... Args>
   SwishNodeOp(Args... args) : UnaryNodeOp(args...) {}
@@ -227,11 +287,13 @@ struct SwishNodeOp : public UnaryNodeOp {
   }
 
   NodeOps backwardOps() {
+    // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
     return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)),
-                       child(0)->grad(),
-                       adj_,
-                       child(0)->val(),
-                       val_))};
+                       child(0)->grad(),  // dJ/dx
+                       adj_,              // _1 := dJ/df
+                       child(0)->val(),   // _2 := x
+                       val_               // _3 := f(x) = x*sigma(x)
+                       ))};
   }
 
   const std::string type() { return "swish"; }
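
For reference, below is a minimal standalone C++ sketch of the activations documented in the diff above and of the derivatives used in their backward passes: PReLU, leaky ReLU as PReLU with the default alpha = 0.01, and swish with f(x) = x * sigma(x). This is plain scalar code for illustration only, not Marian's element-wise kernels; the helper names (preluGrad, swishGrad, sigma) are made up for this example.

#include <cmath>
#include <cstdio>

// PReLU: f(x, alpha) = alpha * x if x <= 0, x if x > 0
float prelu(float x, float alpha = 0.01f) { return x > 0.f ? x : alpha * x; }
// Derivative of PReLU: alpha if x <= 0, 1 if x > 0
float preluGrad(float x, float alpha = 0.01f) { return x > 0.f ? 1.f : alpha; }

// Leaky ReLU is PReLU with the default slope alpha = 0.01
float leakyrelu(float x) { return prelu(x, 0.01f); }

// Swish: f(x) = x * sigma(x); f'(x) = f(x) + sigma(x) * (1 - f(x))
float sigma(float x) { return 1.f / (1.f + std::exp(-x)); }
float swish(float x) { return x * sigma(x); }
float swishGrad(float x) {
  float f = swish(x);
  return f + sigma(x) * (1.f - f);
}

int main() {
  const float xs[] = {-2.f, -0.5f, 0.f, 0.5f, 2.f};
  for (float x : xs)
    std::printf("x=%5.2f  leakyrelu=%7.4f  prelu'=%5.2f  swish=%7.4f  swish'=%7.4f\n",
                x, leakyrelu(x), preluGrad(x), swish(x), swishGrad(x));
  return 0;
}

In the expression-graph API itself, these activations are exposed through the new operators declared in expression_operators.h, i.e. leakyrelu(x) and prelu(x, alpha), with alpha defaulting to 0.01.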