github.com/marian-nmt/marian.git
author    Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk>  2017-10-29 20:57:53 +0300
committer Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk>  2017-10-29 20:57:53 +0300
commit    857289d1bc0afa14d1338f7b6a29cb614b03bcac
tree      fdc3d9926c8b9bd5a020297de8a6634d665e395e
parent    a6d4f2a9ab1ed4b7f559715e9a5b70082f5efd32
Add PReLU
-rw-r--r--  src/graph/expression_operators.cu  |  8
-rw-r--r--  src/graph/expression_operators.h   |  3
-rw-r--r--  src/graph/node_operators_unary.h   | 53
-rw-r--r--  src/kernels/thrust_functions.h     | 34
-rw-r--r--  src/layers/generic.h               |  4
5 files changed, 97 insertions(+), 5 deletions(-)
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index ee059566..c75c3ea0 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -24,6 +24,10 @@ Expr leakyrelu(Expr a) {
return Expression<LeakyReLUNodeOp>(a);
}
+Expr prelu(Expr a, float alpha) {
+ return Expression<PReLUNodeOp>(alpha, a);
+}
+
Expr log(Expr a) {
return Expression<LogNodeOp>(a);
};
@@ -214,6 +218,10 @@ Expr leakyrelu(const std::vector<Expr>&) {
ABORT("Not implemented");
}
+Expr prelu(const std::vector<Expr>&, float alpha) {
+ ABORT("Not implemented");
+}
+
Expr sqrt(Expr a, float eps) {
return Expression<SqrtNodeOp>(a, eps);
}
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index d17fe8b3..b53fb6e4 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -27,6 +27,9 @@ Expr relu(const std::vector<Expr>&);
Expr leakyrelu(Expr a);
Expr leakyrelu(const std::vector<Expr>&);
+Expr prelu(Expr a, float alpha = 0.01);
+Expr prelu(const std::vector<Expr>&, float alpha = 0.01);
+
Expr log(Expr a);
Expr exp(Expr a);
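
With the default \f$ \alpha = 0.01 \f$ declared here, prelu coincides with the existing leakyrelu operator (the node documentation added below makes the same point). Substituting the default into the PReLU definition gives:

\f[
  f(x, 0.01) =
  \begin{cases}
    0.01x & \text{if } x \leq 0 \\
    x & \text{if } x > 0
  \end{cases}
\f]

which is exactly the leaky ReLU activation.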
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 93c12a3c..819192dc 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -243,8 +243,8 @@ struct ReLUNodeOp : public UnaryNodeOp {
* \f[
* f(x) =
* \begin{cases}
- * 0.01 & \text{if } x < 0 \\
- * x & \text{if } x \geq 0
+ * 0.01x & \text{if } x \leq 0 \\
+ * x & \text{if } x > 0
* \end{cases}
* \f]
*
@@ -252,8 +252,8 @@ struct ReLUNodeOp : public UnaryNodeOp {
* \f[
* f^\prime(x) =
* \begin{cases}
- * 0.01 & \text{if } x < 0 \\
- * 1 & \text{if } x \geq 0
+ * 0.01 & \text{if } x \leq 0 \\
+ * 1 & \text{if } x > 0
* \end{cases}
* \f]
*/
@@ -274,6 +274,51 @@ struct LeakyReLUNodeOp : public UnaryNodeOp {
};
/**
+ * Represents a <a
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">parametric
+ * rectified linear unit</a> node in an expression graph.
+ * For \f$ \alpha = 0.01 \f$ (the default value) it is equivalent to Leaky
+ * ReLU.
+ *
+ * This node implements the activation function:
+ * \f[
+ * f(x, \alpha) =
+ * \begin{cases}
+ * \alpha x & \text{if } x \leq 0 \\
+ * x & \text{if } x > 0
+ * \end{cases}
+ * \f]
+ *
+ * and its derivative:
+ * \f[
+ * f^\prime(x, \alpha) =
+ * \begin{cases}
+ * \alpha & \text{if } x \leq 0 \\
+ * 1 & \text{if } x > 0
+ * \end{cases}
+ * \f]
+ */
+struct PReLUNodeOp : public UnaryNodeOp {
+ template <typename... Args>
+ PReLUNodeOp(float alpha, Args... args)
+ : UnaryNodeOp(args...), alpha_(alpha) {}
+
+ NodeOps forwardOps() {
+ return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(
+ Add(_1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))};
+ }
+
+ const std::string type() { return "PReLU"; }
+
+private:
+ float alpha_{0.01};
+};
+
+/**
* Represents a <a href="https://arxiv.org/pdf/1710.05941.pdf">swish</a> node
* in an expression graph.
*
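
A quick way to sanity-check the formulas documented on PReLUNodeOp above is a scalar reference implementation. This is a standalone sketch, not marian code; the helper names prelu_forward/prelu_backward are made up for illustration:

#include <cassert>

// Forward: f(x, alpha) = alpha * x for x <= 0, x otherwise.
inline float prelu_forward(float x, float alpha) {
  return x > 0.0f ? x : alpha * x;
}

// Backward: f'(x, alpha) = alpha for x <= 0, 1 otherwise.
inline float prelu_backward(float x, float alpha) {
  return x > 0.0f ? 1.0f : alpha;
}

int main() {
  assert(prelu_forward(3.0f, 0.01f) == 3.0f);     // positive inputs pass through
  assert(prelu_forward(-2.0f, 0.01f) == -0.02f);  // negative inputs are scaled by alpha
  assert(prelu_backward(-2.0f, 0.01f) == 0.01f);  // gradient is alpha for x <= 0
  assert(prelu_backward(3.0f, 0.01f) == 1.0f);    // gradient is 1 for x > 0
  return 0;
}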
diff --git a/src/kernels/thrust_functions.h b/src/kernels/thrust_functions.h
index 3c46349d..d33f24b7 100644
--- a/src/kernels/thrust_functions.h
+++ b/src/kernels/thrust_functions.h
@@ -141,6 +141,40 @@ __host__ __device__
//*******************************************************************
template <typename T>
+struct binary_prelu : public thrust::binary_function<T, T, T> {
+ __host__ __device__ T operator()(const T &x, const T &alpha) const {
+ return x > 0.0f ? x : alpha * x;
+ }
+};
+
+template <typename T1, typename T2>
+__host__ __device__ actor<composite<binary_operator<binary_prelu>,
+ actor<T1>,
+ typename as_actor<T2>::type>>
+PReLU(const actor<T1> &_1, const T2 &_2) {
+ return compose(
+ binary_operator<binary_prelu>(), make_actor(_1), make_actor(_2));
+}
+
+template <typename T>
+struct binary_preluback : public thrust::binary_function<T, T, T> {
+ __host__ __device__ T operator()(const T &x, const T &alpha) const {
+ return x > 0.0f ? 1.0f : alpha;
+ }
+};
+
+template <typename T1, typename T2>
+__host__ __device__ actor<composite<binary_operator<binary_preluback>,
+ actor<T1>,
+ typename as_actor<T2>::type>>
+PReLUback(const actor<T1> &_1, const T2 &_2) {
+ return compose(
+ binary_operator<binary_preluback>(), make_actor(_1), make_actor(_2));
+}
+
+//*******************************************************************
+
+template <typename T>
__host__ __device__ int sgn(T val) {
return (float(0) < val) - (val < float(0));
}
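
marian applies binary_prelu through its Element/Add expression-template machinery (see PReLUNodeOp above). As a simplified standalone analogue, the same functor body can be applied elementwise with plain thrust::transform, broadcasting alpha through a constant_iterator; this pairing scheme is an assumption made for illustration, not how marian binds alpha:

#include <thrust/device_vector.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/transform.h>

struct prelu_op {
  __host__ __device__ float operator()(float x, float alpha) const {
    return x > 0.0f ? x : alpha * x;  // same body as binary_prelu
  }
};

int main() {
  thrust::device_vector<float> x(4);
  x[0] = -2.0f; x[1] = -0.5f; x[2] = 0.0f; x[3] = 3.0f;
  thrust::device_vector<float> y(4);

  thrust::transform(x.begin(), x.end(),
                    thrust::constant_iterator<float>(0.01f),  // broadcast alpha
                    y.begin(), prelu_op());
  // y is now {-0.02, -0.005, 0.0, 3.0}
  return 0;
}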
diff --git a/src/layers/generic.h b/src/layers/generic.h
index ed1ab222..dea129d6 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -9,7 +9,7 @@
namespace marian {
namespace mlp {
-enum struct act : int { linear, tanh, logit, ReLU, LeakyReLU, swish };
+enum struct act : int { linear, tanh, logit, ReLU, LeakyReLU, PReLU, swish };
}
}
@@ -129,6 +129,7 @@ public:
case act::logit: return logit(outputs);
case act::ReLU: return relu(outputs);
case act::LeakyReLU: return leakyrelu(outputs);
+ case act::PReLU: return prelu(outputs);
case act::swish: return swish(outputs);
default: return plus(outputs);
}
@@ -188,6 +189,7 @@ public:
case act::logit: return logit(out);
case act::ReLU: return relu(out);
case act::LeakyReLU: return leakyrelu(out);
+ case act::PReLU: return prelu(out);
case act::swish: return swish(out);
default: return out;
}
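
Note that act::PReLU in the first switch dispatches to the vector overload prelu(outputs), which this commit stubs out with ABORT("Not implemented"); only the single-expression path in the second switch has a working implementation so far. For reference, a scalar sketch of the dispatch pattern extended above (a standalone illustration, not marian's Expr-based code):

#include <cmath>

enum struct act : int { linear, tanh, logit, ReLU, LeakyReLU, PReLU, swish };

float activate(act a, float x, float alpha = 0.01f) {
  switch (a) {
    case act::tanh:      return std::tanh(x);
    case act::logit:     return 1.0f / (1.0f + std::exp(-x));  // logistic sigmoid
    case act::ReLU:      return x > 0.0f ? x : 0.0f;
    case act::LeakyReLU: return x > 0.0f ? x : 0.01f * x;
    case act::PReLU:     return x > 0.0f ? x : alpha * x;
    case act::swish:     return x / (1.0f + std::exp(-x));     // x * sigmoid(x)
    default:             return x;                             // act::linear
  }
}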