Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2017-10-29 17:43:28 +0300
committerRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2017-10-29 20:26:35 +0300
commit2038b3cb2dbfa81876c71de6eb4bd953b8ef7587 (patch)
treea554f5fbf690c9687e62f067aa75878632b01571 /src/graph/node_operators_unary.h
parent505353d76dc95fa3df703b4ff908fa489eb0e85b (diff)
Add LeakyReLU
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h40
1 files changed, 40 insertions, 0 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 6a78da67..93c12a3c 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -234,6 +234,46 @@ struct ReLUNodeOp : public UnaryNodeOp {
};
/**
+ * Represents a <a
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">leaky
+ * rectified linear unit</a> node in an expression graph.
+ * It is equivalent to the parametric ReLU with \f$ \alpha = 0.01 \f$.
+ *
+ * This node implements the activation function:
+ * \f[
+ * f(x) =
+ * \begin{cases}
+ * 0.01 & \text{if } x < 0 \\
+ * x & \text{if } x \geq 0
+ * \end{cases}
+ * \f]
+ *
+ * and its derivative:
+ * \f[
+ * f^\prime(x) =
+ * \begin{cases}
+ * 0.01 & \text{if } x < 0 \\
+ * 1 & \text{if } x \geq 0
+ * \end{cases}
+ * \f]
+ */
+struct LeakyReLUNodeOp : public UnaryNodeOp {
+ template <typename... Args>
+ LeakyReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
+
+ NodeOps forwardOps() {
+ return {NodeOp(Element(_1 = LeakyReLU(_2), val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(
+ Add(_1 * LeakyReLUback(_2), child(0)->grad(), adj_, child(0)->val()))};
+ }
+
+ const std::string type() { return "LeakyReLU"; }
+};
+
+/**
* Represents a <a href="https://arxiv.org/pdf/1710.05941.pdf">swish</a> node
* in an expression graph.
*