Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2017-10-29 20:57:53 +0300
committerRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2017-10-29 20:57:53 +0300
commit857289d1bc0afa14d1338f7b6a29cb614b03bcac (patch)
treefdc3d9926c8b9bd5a020297de8a6634d665e395e /src/graph/node_operators_unary.h
parenta6d4f2a9ab1ed4b7f559715e9a5b70082f5efd32 (diff)
Add PReLU
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h53
1 files changed, 49 insertions, 4 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 93c12a3c..819192dc 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -243,8 +243,8 @@ struct ReLUNodeOp : public UnaryNodeOp {
* \f[
* f(x) =
* \begin{cases}
- * 0.01 & \text{if } x < 0 \\
- * x & \text{if } x \geq 0
+ * 0.01 & \text{if } x \leq 0 \\
+ * x & \text{if } x > 0
* \end{cases}
* \f]
*
@@ -252,8 +252,8 @@ struct ReLUNodeOp : public UnaryNodeOp {
* \f[
* f^\prime(x) =
* \begin{cases}
- * 0.01 & \text{if } x < 0 \\
- * 1 & \text{if } x \geq 0
+ * 0.01 & \text{if } x \leq 0 \\
+ * 1 & \text{if } x > 0
* \end{cases}
* \f]
*/
@@ -274,6 +274,51 @@ struct LeakyReLUNodeOp : public UnaryNodeOp {
};
/**
+ * Represents a <a
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">parametric
+ * rectified linear unit</a> node in an expression graph.
+ * For \f$ \alpha = 0.01 \f$ (the default value) it is equivalent to Leaky
+ * ReLU.
+ *
+ * This node implements the activation function:
+ * \f[
+ * f(x, \alpha) =
+ * \begin{cases}
+ * \alpha x & \text{if } x \leq 0 \\
+ * x & \text{if } x > 0
+ * \end{cases}
+ * \f]
+ *
+ * and its derivative:
+ * \f[
+ * f^\prime(x, \alpha) =
+ * \begin{cases}
+ * \alpha & \text{if } x \leq 0 \\
+ * 1 & \text{if } x > 0
+ * \end{cases}
+ * \f]
+ */
+struct PReLUNodeOp : public UnaryNodeOp {
+ template <typename... Args>
+ PReLUNodeOp(float alpha, Args... args)
+ : UnaryNodeOp(args...), alpha_(alpha) {}
+
+ NodeOps forwardOps() {
+ return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(
+ Add(_1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))};
+ }
+
+ const std::string type() { return "PReLU"; }
+
+private:
+ float alpha_{0.01};
+};
+
+/**
* Represents a <a href="https://arxiv.org/pdf/1710.05941.pdf">swish</a> node
* in an expression graph.
*