diff options
author | Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> | 2017-10-29 20:57:53 +0300 |
---|---|---|
committer | Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> | 2017-10-29 20:57:53 +0300 |
commit | 857289d1bc0afa14d1338f7b6a29cb614b03bcac (patch) | |
tree | fdc3d9926c8b9bd5a020297de8a6634d665e395e /src/graph/node_operators_unary.h | |
parent | a6d4f2a9ab1ed4b7f559715e9a5b70082f5efd32 (diff) |
Add PReLU
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 53 |
1 file changed, 49 insertions, 4 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 93c12a3c..819192dc 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -243,8 +243,8 @@ struct ReLUNodeOp : public UnaryNodeOp { * \f[ * f(x) = * \begin{cases} - * 0.01 & \text{if } x < 0 \\ - * x & \text{if } x \geq 0 + * 0.01 & \text{if } x \leq 0 \\ + * x & \text{if } x > 0 * \end{cases} * \f] * @@ -252,8 +252,8 @@ struct ReLUNodeOp : public UnaryNodeOp { * \f[ * f^\prime(x) = * \begin{cases} - * 0.01 & \text{if } x < 0 \\ - * 1 & \text{if } x \geq 0 + * 0.01 & \text{if } x \leq 0 \\ + * 1 & \text{if } x > 0 * \end{cases} * \f] */ @@ -274,6 +274,51 @@ struct LeakyReLUNodeOp : public UnaryNodeOp { }; /** + * Represents a <a + * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">parametric + * rectified linear unit</a> node in an expression graph. + * For \f$ \alpha = 0.01 \f$ (the default value) it is equivalent to Leaky + * ReLU. + * + * This node implements the activation function: + * \f[ + * f(x, \alpha) = + * \begin{cases} + * \alpha x & \text{if } x \leq 0 \\ + * x & \text{if } x > 0 + * \end{cases} + * \f] + * + * and its derivative: + * \f[ + * f^\prime(x, \alpha) = + * \begin{cases} + * \alpha & \text{if } x \leq 0 \\ + * 1 & \text{if } x > 0 + * \end{cases} + * \f] + */ +struct PReLUNodeOp : public UnaryNodeOp { + template <typename... Args> + PReLUNodeOp(float alpha, Args... args) + : UnaryNodeOp(args...), alpha_(alpha) {} + + NodeOps forwardOps() { + return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))}; + } + + NodeOps backwardOps() { + return {NodeOp( + Add(_1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))}; + } + + const std::string type() { return "PReLU"; } + +private: + float alpha_{0.01}; +}; + +/** * Represents a <a href="https://arxiv.org/pdf/1710.05941.pdf">swish</a> node * in an expression graph. * |