
github.com/marian-nmt/marian.git
author    Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-11-02 00:44:30 +0300
committer Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-11-02 00:44:30 +0300
commit    289c25cb2ca0df9077480f8e843e99ea2ef8f018 (patch)
tree      89ad435c7feea3e14c6cd7245e00d0b84480e8fd /src/graph/node_operators_unary.h
parent    0d20bb57c19290eb01c0e1d4ad539d94571b0f80 (diff)
parent    e7bc3c7e68d075c4f14c15f5d7a46925f14d02f9 (diff)
merge with develop
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--  src/graph/node_operators_unary.h | 104
1 file changed, 83 insertions(+), 21 deletions(-)
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 4457ad2e..844f8905 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -186,38 +186,98 @@ struct TanhNodeOp : public NaryNodeOp {
/**
* Represents a <a
-href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
-linear</a> node
- * in an expression graph.
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
+ * linear</a> node in an expression graph.
*
- * This node implements the <a
-href="https://en.wikipedia.org/wiki/Activation_function">activation function</a>
- * \f$f(x) = \max(0, x)\f$ and its derivative:
- *
- \f[
- f^\prime(x) =
- \begin{cases}
- 0 & \text{if } x \leq 0 \\
- 1 & \text{if } x > 0
- \end{cases}
-\f]
+ * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and
+ * its derivative:
+ * \f[
+ * f^\prime(x) =
+ * \begin{cases}
+ * 0 & \text{if } x \leq 0 \\
+ * 1 & \text{if } x > 0
+ * \end{cases}
+ * \f]
*/
struct ReLUNodeOp : public UnaryNodeOp {
template <typename... Args>
ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = ReLU(_2), val_, child(0)->val()))};
+ // f(x) = max(0, x)
+ return {NodeOp(Element(_1 = ReLU(_2),
+ val_, // _1 := f(x) to be calculated
+ child(0)->val() // _2 := x
+ ))};
}
NodeOps backwardOps() {
- return {NodeOp(
- Add(_1 * ReLUback(_2), child(0)->grad(), adj_, child(0)->val()))};
+ // dJ/dx += dJ/df * binarystep(x)
+ return {NodeOp(Add(_1 * ReLUback(_2),
+ child(0)->grad(), // dJ/dx
+ adj_, // _1 := dJ/df
+ child(0)->val() // _2 := x
+ ))};
}
const std::string type() { return "ReLU"; }
};
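
For reference, the forward/backward pair above reduces to a simple per-element computation. Below is a minimal standalone scalar sketch of that computation; it assumes plain C++ floats and illustrative names (relu, relu_back, adj, grad), not Marian's Element/Add tensor API:

#include <algorithm>
#include <cstdio>

// f(x) = max(0, x)
float relu(float x) { return std::max(0.0f, x); }

// binary step: 1 if x > 0, else 0
float relu_back(float x) { return x > 0.0f ? 1.0f : 0.0f; }

int main() {
  float x = -1.5f;
  float adj = 1.0f;   // dJ/df, the gradient arriving from above (adj_)
  float grad = 0.0f;  // dJ/dx accumulator (child(0)->grad())
  float fx = relu(x);          // forward: val_
  grad += adj * relu_back(x);  // backward: dJ/dx += dJ/df * step(x)
  std::printf("f(%g) = %g, dJ/dx = %g\n", x, fx, grad);
  return 0;
}
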
+/**
+ * Represents a <a
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">parametric
+ * rectified linear unit</a> node in an expression graph.
+ * For \f$ \alpha = 0.01 \f$ (the default value) it is equivalent to Leaky
+ * ReLU.
+ *
+ * This node implements the activation function:
+ * \f[
+ * f(x, \alpha) =
+ * \begin{cases}
+ * \alpha x & \text{if } x \leq 0 \\
+ * x & \text{if } x > 0
+ * \end{cases}
+ * \f]
+ *
+ * and its derivative:
+ * \f[
+ * f^\prime(x, \alpha) =
+ * \begin{cases}
+ * \alpha & \text{if } x \leq 0 \\
+ * 1 & \text{if } x > 0
+ * \end{cases}
+ * \f]
+ */
+struct PReLUNodeOp : public UnaryNodeOp {
+ template <typename... Args>
+ PReLUNodeOp(float alpha, Args... args)
+ : UnaryNodeOp(args...), alpha_(alpha) {}
+
+ NodeOps forwardOps() {
+ return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(Add(
+ _1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))};
+ }
+
+ const std::string type() { return "PReLU"; }
+
+private:
+ float alpha_{0.01};
+};
+
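
The PReLU formulas above can be checked with a scalar sketch in the same spirit (plain C++ with illustrative names prelu and prelu_back, not Marian code). With alpha = 0.01, the default, this reproduces leaky ReLU as noted in the doc comment:

#include <cstdio>
#include <initializer_list>

// f(x, alpha) = x if x > 0, else alpha * x
float prelu(float x, float alpha) { return x > 0.0f ? x : alpha * x; }

// f'(x, alpha) = 1 if x > 0, else alpha
float prelu_back(float x, float alpha) { return x > 0.0f ? 1.0f : alpha; }

int main() {
  const float alpha = 0.01f;  // default value, i.e. leaky ReLU
  for (float x : {-2.0f, 3.0f})
    std::printf("f(%g) = %g, f'(%g) = %g\n",
                x, prelu(x, alpha), x, prelu_back(x, alpha));
  return 0;
}
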
+/**
+ * Represents a <a href="https://arxiv.org/pdf/1710.05941.pdf">swish</a> node
+ * in an expression graph.
+ *
+ * This node implements the activation function
+ * \f$ f(x) = x \cdot \sigma(x) \f$
+ * and its derivative
+ * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$.
+ *
+ */
struct SwishNodeOp : public UnaryNodeOp {
template <typename... Args>
SwishNodeOp(Args... args) : UnaryNodeOp(args...) {}
@@ -227,11 +287,13 @@ struct SwishNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
+ // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)),
- child(0)->grad(),
- adj_,
- child(0)->val(),
- val_))};
+ child(0)->grad(), // dJ/dx
+ adj_, // _1 := dJ/df
+ child(0)->val(), // _2 := x
+ val_ // _3 := f(x) = x*sigma(x)
+ ))};
}
const std::string type() { return "swish"; }
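
The swish forward/backward above likewise reduces to a scalar check of f(x) = x * sigma(x) and the derivative identity f'(x) = f(x) + sigma(x)(1 - f(x)) from the doc comment (again a sketch with illustrative names, not Marian's API):

#include <cmath>
#include <cstdio>
#include <initializer_list>

float sigma(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// f(x) = x * sigma(x)
float swish(float x) { return x * sigma(x); }

// f'(x) = f(x) + sigma(x) * (1 - f(x))
float swish_back(float x) {
  float fx = swish(x);
  return fx + sigma(x) * (1.0f - fx);
}

int main() {
  for (float x : {-1.0f, 0.0f, 1.0f})
    std::printf("f(%g) = %g, f'(%g) = %g\n", x, swish(x), x, swish_back(x));
  return 0;
}
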