diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2017-02-18 00:02:25 +0300 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2017-02-18 00:05:45 +0300 |
commit | 51cc7c5c2e6743e96932dc16b3959d55f8619b70 (patch) | |
tree | ce26dd1147a6acf96d2c8db35f7418cd64e01370 | |
parent | 6726d1a6a3cfbe9c1b5d79361dda952aa9e663a3 (diff) |
Use a more stable formula for spatial LogSoftMax
-rw-r--r-- | lib/THCUNN/LogSoftMax.cu | 13 |
1 file changed, 9 insertions, 4 deletions
```diff
diff --git a/lib/THCUNN/LogSoftMax.cu b/lib/THCUNN/LogSoftMax.cu
index 4d7973e..98b7670 100644
--- a/lib/THCUNN/LogSoftMax.cu
+++ b/lib/THCUNN/LogSoftMax.cu
@@ -21,11 +21,17 @@ __global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input,
       (width*classSize)*y +
       (classSize)*x;
 
+    T maxInput = input[inputStartIndex];
+    for (int i = 1; i < classSize; i++) {
+      T value = input[inputStartIndex + i];
+      maxInput = THCNumerics<T>::ge(maxInput, value) ? maxInput : value;
+    }
+
     AccumT sum = 0;
     for (int i = 0; i < classSize; i++) {
-      sum += THCNumerics<T>::exp(input[inputStartIndex + i]);
+      sum += THCNumerics<T>::exp(input[inputStartIndex + i] - maxInput);
     }
-    sum = AccumT(1) / sum;
+    T logsum = maxInput + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum));
 
     for (int i = 0; i < classSize; i++) {
       // calculate output index in torch layout (B x C x H x W)
@@ -34,8 +40,7 @@ __global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input,
         (height*width)*i +
         (width)*y +
         x;
-      output[outputIndex] = ScalarConvert<AccumT, T>::to(
-        THCNumerics<AccumT>::log(sum * THCNumerics<T>::exp(input[inputStartIndex + i])));
+      output[outputIndex] = input[inputStartIndex + i] - logsum;
     }
     index += blockDim.x;
   }
```