github.com/torch/cunn.git
author    Alykhan Tejani <alykhan.tejani@gmail.com>  2017-07-24 21:58:43 +0300
committer Soumith Chintala <soumith@gmail.com>       2017-08-27 05:19:48 +0300
commit    8980c40b0c4fce7736e9837fecbe2c7ac437d34b (patch)
tree      1dcaa144cfffc8ce5ef1fce96e8b894d4f8e9f79
parent    283539e161cf73955de64ff9a596b9eb8358dfd2 (diff)
Add numerically stable logsigmoid
-rw-r--r--  lib/THCUNN/LogSigmoid.cu | 48 ++++++++++++++++++++++++++++++++++++++----------
1 file changed, 38 insertions(+), 10 deletions(-)
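
Why the change: the old forward pass evaluated logsigmoid(x) = -log(1 + exp(-x)) directly. For large negative inputs, exp(-x) overflows (in float32, once x drops below roughly -88), so the result collapses to -inf even though logsigmoid(-1000) is simply about -1000. The patch applies the standard log-sum-exp shift: with max = fmax(0, -x),

    log(1 + exp(-x)) = max + log(exp(-max) + exp(-x - max)),

and both exponents on the right are non-positive, so neither exp() can overflow. Worked example at x = -1000: max = 1000, exp(-max) + exp(-x - max) = exp(-1000) + exp(0) ≈ 1, so the output is -(1000 + log(1)) = -1000, as expected.

A minimal host-side sketch of the technique in plain C++ (an illustration with hypothetical helper names, not the THCUNN kernel itself):

#include <cmath>
#include <cstdio>

// Naive form: exp(-x) overflows once -x exceeds ~709 in double precision,
// turning the result into -inf.
double logsigmoid_naive(double x) {
  return -std::log(1.0 + std::exp(-x));
}

// Stable form used by this commit: with m = max(0, -x), both exponents
// -m and -x - m are <= 0, so neither exp() can overflow.
double logsigmoid_stable(double x) {
  double m = std::fmax(0.0, -x);
  double z = std::exp(-m) + std::exp(-x - m);
  return -(m + std::log(z));
}

int main() {
  printf("naive(-1000)  = %f\n", logsigmoid_naive(-1000.0));   // -inf
  printf("stable(-1000) = %f\n", logsigmoid_stable(-1000.0));  // -1000.0
  return 0;
}
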
diff --git a/lib/THCUNN/LogSigmoid.cu b/lib/THCUNN/LogSigmoid.cu
index bb9753a..dbdfea5 100644
--- a/lib/THCUNN/LogSigmoid.cu
+++ b/lib/THCUNN/LogSigmoid.cu
@@ -7,7 +7,9 @@ template <typename T>
struct logSigmoid_updateOutput_functor
{
__device__ void operator()(T *output, const T *input) const {
- *output = -THCNumerics<T>::log(1.f + THCNumerics<T>::exp(- *input));
+ const T max = fmaxType(0.f, - *input);
+ const T z = THCNumerics<T>::exp(-max) + THCNumerics<T>::exp(-*input -max);
+ *output = -(max + THCNumerics<T>::log(z));
}
};
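A quick case check on the new forward expression: for x >= 0, max = 0 and z = 1 + exp(-x), so -(max + log(z)) reduces exactly to the old formula; for x < 0, max = -x and z = exp(x) + 1, giving x - log(1 + exp(x)), the same value with only non-positive arguments ever handed to exp().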
@@ -15,8 +17,15 @@ template <typename T>
struct logSigmoid_updateGradInput_functor
{
__device__ void operator()(T *gradInput, const T *input, const T *gradOutput) const {
- const T z = THCNumerics<T>::exp(- *input);
- *gradInput = *gradOutput * z / (1.f + z);
+ const T max = fmaxType(0.f, -*input);
+ const T z = THCNumerics<T>::exp(-max) + THCNumerics<T>::exp(-*input -max);
+ T max_deriv = 0.f;
+ T sign = -1.f;
+ if (*input < 0.f){
+ max_deriv = -1.f;
+ sign = 1.f;
+ }
+ *gradInput = *gradOutput * (-max_deriv - sign*((z - 1.f)/z));
}
};
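The backward pass uses the same trick. The derivative of logsigmoid(x) is sigmoid(-x) = exp(-x) / (1 + exp(-x)), which is what the old code computed and which overflows in the same regime. The new code reuses z from the forward rewrite and branches on the sign of the input: for x >= 0 (max_deriv = 0, sign = -1) the gradient is gradOutput * (z - 1)/z = gradOutput * exp(-x)/(1 + exp(-x)); for x < 0 (max_deriv = -1, sign = 1) it is gradOutput * (1 - (z - 1)/z) = gradOutput / z = gradOutput / (1 + exp(x)). Both branches equal gradOutput * sigmoid(-x), with no overflowing term.

The same logic in a standalone C++ sketch (a hypothetical helper, with variable names chosen to mirror the kernel):

#include <cmath>

double logsigmoid_grad_stable(double x, double grad_out) {
  double m = std::fmax(0.0, -x);
  double z = std::exp(-m) + std::exp(-x - m);
  // d(max(0, -x))/dx is 0 for x >= 0 and -1 for x < 0; 'sign' flips the
  // remaining term so that both branches reduce to sigmoid(-x).
  double max_deriv = 0.0, sign = -1.0;
  if (x < 0.0) { max_deriv = -1.0; sign = 1.0; }
  return grad_out * (-max_deriv - sign * ((z - 1.0) / z));
}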
@@ -25,11 +34,14 @@ template <>
struct logSigmoid_updateOutput_functor<half> {
__device__ __forceinline__ void operator()(half* output, const half *input) const {
#ifdef CUDA_HALF_INSTRUCTIONS
- const half one = __float2half(1.f);
- *output = __hneg(THCNumerics<half>::log(one + THCNumerics<half>::exp(__hneg(*input))));
+ const half max = fmaxType(__float2half(0.f), __hneg(*input));
+ const half z = THCNumerics<half>::exp(__hneg(max)) + THCNumerics<half>::exp(__hneg(*input) - max);
+ *output = __hneg(max + THCNumerics<half>::log(z));
#else
float in = __half2float(*input);
- *output = __float2half(-THCNumerics<float>::log(1.f + THCNumerics<float>::exp(-in)));
+ float max = fmaxType(0.f, -in);
+ float z = THCNumerics<float>::exp(-max) + THCNumerics<float>::exp(-in - max);
+ *output = __float2half(-(max + THCNumerics<float>::log(z)));
#endif
}
};
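The half specialization follows the same formula. With native half instructions (CUDA_HALF_INSTRUCTIONS) it stays in half precision using __hneg and fmaxType; otherwise it widens to float, applies the shift, and narrows back. A sketch of that fallback pattern as a standalone device helper (my illustration, assuming cuda_fp16.h; the real code lives inside the functor):

#include <cuda_fp16.h>

__device__ half logsigmoid_half_fallback(half h) {
  float in = __half2float(h);             // widen: half -> float
  float m  = fmaxf(0.f, -in);             // shift keeps both exponents <= 0
  float z  = expf(-m) + expf(-in - m);
  return __float2half(-(m + logf(z)));    // narrow: float -> half
}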
@@ -39,12 +51,28 @@ struct logSigmoid_updateGradInput_functor<half> {
__device__ __forceinline__ void operator()(half* gradInput, const half *input, const half *gradOutput) const {
#ifdef CUDA_HALF_INSTRUCTIONS
const half one = __float2half(1.f);
- const half in_exp = THCNumerics<half>::exp(__hneg(*input));
- *gradInput = hdiv(__hmul(*gradOutput, in_exp), __hadd(one, in_exp));
+ const half zero = __float2half(0.f);
+ const half max = fmaxType(zero, __hneg(*input));
+ const half z = THCNumerics<half>::exp(__hneg(max)) + THCNumerics<half>::exp(__hneg(*input) - max);
+ half max_deriv = zero;
+ half sign = __hneg(one);
+ if(*input < zero){
+ max_deriv = __hneg(one);
+ sign = one;
+ }
+ *gradInput = __hmul(*gradOutput, (__hneg(max_deriv) - __hmul(sign, __hdiv(z - one, z))));
#else
- const float in_exp = THCNumerics<float>::exp(-(__half2float(*input)));
+ const float in = __half2float(*input);
+ const float max = fmaxType(0.f, -in);
+ const float z = THCNumerics<float>::exp(-max) + THCNumerics<float>::exp(-in - max);
const float go = __half2float(*gradOutput);
- *gradInput = __float2half(go * in_exp / (1.f + in_exp));
+ float max_deriv = 0.f;
+ float sign = -1.f;
+ if(in < 0.f){
+ max_deriv = -1.f;
+ sign = 1.f;
+ }
+ *gradInput = __float2half(go * (-max_deriv - sign*((z - 1.f)/z)));
#endif
}
};
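
The half backward specialization repeats the float logic branch for branch; the CUDA_HALF_INSTRUCTIONS path additionally leans on half intrinsics (__hneg, __hmul, __hdiv) and, for z - one and the comparison against zero, appears to rely on the half operator overloads shipped by newer CUDA toolkits. A quick host-side check (a hypothetical harness, not part of the commit) confirms the stable form agrees with the naive one where the latter is finite and stays finite where it overflows:

#include <cmath>
#include <cstdio>

int main() {
  const double xs[] = {-1000.0, -20.0, 0.0, 20.0, 1000.0};
  for (double x : xs) {
    double m = std::fmax(0.0, -x);
    double z = std::exp(-m) + std::exp(-x - m);
    printf("x=%8.1f  stable=%12.6f  naive=%12.6f\n",
           x, -(m + std::log(z)), -std::log(1.0 + std::exp(-x)));
  }
  return 0;
}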