author     Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk>   2017-10-29 20:57:53 +0300
committer  Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk>   2017-10-29 20:57:53 +0300
commit     857289d1bc0afa14d1338f7b6a29cb614b03bcac (patch)
tree       fdc3d9926c8b9bd5a020297de8a6634d665e395e
parent     a6d4f2a9ab1ed4b7f559715e9a5b70082f5efd32 (diff)
Add PReLU
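This commit adds a parametric rectified linear unit (PReLU) activation: a new `prelu` operator in the expression-operator API, a `PReLUNodeOp` with forward and backward ops, element-wise `PReLU`/`PReLUback` functors in the thrust kernels, and an `act::PReLU` case in the mlp layers. With the default slope `alpha = 0.01` it is equivalent to leaky ReLU. A minimal usage sketch of the new operator follows; the `Expr x` and the surrounding graph setup are illustrative and not part of this commit:

```cpp
// Illustrative fragment only: assumes `x` is an existing marian Expr
// built elsewhere in an expression graph.
Expr y = prelu(x);         // default alpha = 0.01, i.e. leaky-ReLU behaviour
Expr z = prelu(x, 0.05f);  // steeper slope on the negative side
```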
-rw-r--r--  src/graph/expression_operators.cu |  8
-rw-r--r--  src/graph/expression_operators.h  |  3
-rw-r--r--  src/graph/node_operators_unary.h  | 53
-rw-r--r--  src/kernels/thrust_functions.h    | 34
-rw-r--r--  src/layers/generic.h              |  4
5 files changed, 97 insertions, 5 deletions
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index ee059566..c75c3ea0 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -24,6 +24,10 @@ Expr leakyrelu(Expr a) {
   return Expression<LeakyReLUNodeOp>(a);
 }
 
+Expr prelu(Expr a, float alpha) {
+  return Expression<PReLUNodeOp>(alpha, a);
+}
+
 Expr log(Expr a) {
   return Expression<LogNodeOp>(a);
 };
@@ -214,6 +218,10 @@ Expr leakyrelu(const std::vector<Expr>&) {
   ABORT("Not implemented");
 }
 
+Expr prelu(const std::vector<Expr>&, float alpha) {
+  ABORT("Not implemented");
+}
+
 Expr sqrt(Expr a, float eps) {
   return Expression<SqrtNodeOp>(a, eps);
 }
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index d17fe8b3..b53fb6e4 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -27,6 +27,9 @@ Expr relu(const std::vector<Expr>&);
 Expr leakyrelu(Expr a);
 Expr leakyrelu(const std::vector<Expr>&);
 
+Expr prelu(Expr a, float alpha = 0.01);
+Expr prelu(const std::vector<Expr>&, float alpha = 0.01);
+
 Expr log(Expr a);
 
 Expr exp(Expr a);
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 93c12a3c..819192dc 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -243,8 +243,8 @@ struct ReLUNodeOp : public UnaryNodeOp {
  * \f[
  *   f(x) =
  *   \begin{cases}
- *     0.01 & \text{if } x < 0 \\
- *     x & \text{if } x \geq 0
+ *     0.01 & \text{if } x \leq 0 \\
+ *     x & \text{if } x > 0
  *   \end{cases}
  * \f]
  *
@@ -252,8 +252,8 @@ struct ReLUNodeOp : public UnaryNodeOp {
  * \f[
  *   f^\prime(x) =
  *   \begin{cases}
- *     0.01 & \text{if } x < 0 \\
- *     1 & \text{if } x \geq 0
+ *     0.01 & \text{if } x \leq 0 \\
+ *     1 & \text{if } x > 0
  *   \end{cases}
  * \f]
  */
@@ -274,6 +274,51 @@ struct LeakyReLUNodeOp : public UnaryNodeOp {
 };
 
 /**
+ * Represents a <a
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">parametric
+ * rectified linear unit</a> node in an expression graph.
+ * For \f$ \alpha = 0.01 \f$ (the default value) it is equivalent to Leaky
+ * ReLU.
+ *
+ * This node implements the activation function:
+ * \f[
+ *   f(x, \alpha) =
+ *   \begin{cases}
+ *     \alpha x & \text{if } x \leq 0 \\
+ *     x & \text{if } x > 0
+ *   \end{cases}
+ * \f]
+ *
+ * and its derivative:
+ * \f[
+ *   f^\prime(x, \alpha) =
+ *   \begin{cases}
+ *     \alpha & \text{if } x \leq 0 \\
+ *     1 & \text{if } x > 0
+ *   \end{cases}
+ * \f]
+ */
+struct PReLUNodeOp : public UnaryNodeOp {
+  template <typename... Args>
+  PReLUNodeOp(float alpha, Args... args)
+      : UnaryNodeOp(args...), alpha_(alpha) {}
+
+  NodeOps forwardOps() {
+    return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
+  }
+
+  NodeOps backwardOps() {
+    return {NodeOp(
+        Add(_1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))};
+  }
+
+  const std::string type() { return "PReLU"; }
+
+private:
+  float alpha_{0.01};
+};
+
+/**
  * Represents a <a href="https://arxiv.org/pdf/1710.05941.pdf">swish</a> node
  * in an expression graph.
  *
diff --git a/src/kernels/thrust_functions.h b/src/kernels/thrust_functions.h
index 3c46349d..d33f24b7 100644
--- a/src/kernels/thrust_functions.h
+++ b/src/kernels/thrust_functions.h
@@ -141,6 +141,40 @@ __host__ __device__
 //*******************************************************************
 
 template <typename T>
+struct binary_prelu : public thrust::binary_function<T, T, T> {
+  __host__ __device__ T operator()(const T &x, const T &alpha) const {
+    return x > 0.0f ? x : alpha * x;
+  }
+};
+
+template <typename T1, typename T2>
+__host__ __device__ actor<composite<binary_operator<binary_prelu>,
+                                     actor<T1>,
+                                     typename as_actor<T2>::type>>
+PReLU(const actor<T1> &_1, const T2 &_2) {
+  return compose(
+      binary_operator<binary_prelu>(), make_actor(_1), make_actor(_2));
+}
+
+template <typename T>
+struct binary_preluback : public thrust::binary_function<T, T, T> {
+  __host__ __device__ T operator()(const T &x, const T &alpha) const {
+    return x > 0.0f ? 1.0f : alpha;
+  }
+};
+
+template <typename T1, typename T2>
+__host__ __device__ actor<composite<binary_operator<binary_preluback>,
+                                     actor<T1>,
+                                     typename as_actor<T2>::type>>
+PReLUback(const actor<T1> &_1, const T2 &_2) {
+  return compose(
+      binary_operator<binary_preluback>(), make_actor(_1), make_actor(_2));
+}
+
+//*******************************************************************
+
+template <typename T>
 __host__ __device__ int sgn(T val) {
   return (float(0) < val) - (val < float(0));
 }
diff --git a/src/layers/generic.h b/src/layers/generic.h
index ed1ab222..dea129d6 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -9,7 +9,7 @@
 
 namespace marian {
 namespace mlp {
-enum struct act : int { linear, tanh, logit, ReLU, LeakyReLU, swish };
+enum struct act : int { linear, tanh, logit, ReLU, LeakyReLU, PReLU, swish };
 }
 }
 
@@ -129,6 +129,7 @@ public:
       case act::logit: return logit(outputs);
       case act::ReLU: return relu(outputs);
       case act::LeakyReLU: return leakyrelu(outputs);
+      case act::PReLU: return prelu(outputs);
       case act::swish: return swish(outputs);
       default: return plus(outputs);
     }
@@ -188,6 +189,7 @@ public:
      case act::logit: return logit(out);
      case act::ReLU: return relu(out);
      case act::LeakyReLU: return leakyrelu(out);
+     case act::PReLU: return prelu(out);
      case act::swish: return swish(out);
      default: return out;
    }
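For reference, here is a plain C++ restatement of the forward and backward formulas documented in the new `PReLUNodeOp` comment above. This is a standalone sketch, not Marian code, and the helper names `prelu_forward`/`prelu_backward` are illustrative:

```cpp
#include <cstdio>

// f(x, alpha)  = x  if x > 0,  alpha * x  otherwise
// f'(x, alpha) = 1  if x > 0,  alpha      otherwise
static float prelu_forward(float x, float alpha) {
  return x > 0.0f ? x : alpha * x;
}

static float prelu_backward(float x, float alpha) {
  return x > 0.0f ? 1.0f : alpha;
}

int main() {
  const float alpha = 0.01f;  // default slope: PReLU reduces to leaky ReLU
  const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
  for (float x : xs) {
    std::printf("x=%5.2f  f=%8.4f  f'=%5.2f\n",
                x, prelu_forward(x, alpha), prelu_backward(x, alpha));
  }
  return 0;
}
```

In the node itself, the forward op applies `PReLU(_2, alpha_)` element-wise to the child's values, and the backward op accumulates `adj_ * PReLUback(x, alpha_)` into the child's gradient, which is the chain rule for this piecewise-linear function.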