Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2017-02-18 00:02:25 +0300
committerAdam Paszke <adam.paszke@gmail.com>2017-02-18 00:05:45 +0300
commit51cc7c5c2e6743e96932dc16b3959d55f8619b70 (patch)
treece26dd1147a6acf96d2c8db35f7418cd64e01370
parent6726d1a6a3cfbe9c1b5d79361dda952aa9e663a3 (diff)
Use a more stable formula for spatial LogSoftMax
-rw-r--r--lib/THCUNN/LogSoftMax.cu13
1 files changed, 9 insertions, 4 deletions
diff --git a/lib/THCUNN/LogSoftMax.cu b/lib/THCUNN/LogSoftMax.cu
index 4d7973e..98b7670 100644
--- a/lib/THCUNN/LogSoftMax.cu
+++ b/lib/THCUNN/LogSoftMax.cu
@@ -21,11 +21,17 @@ __global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input,
(width*classSize)*y +
(classSize)*x;
+ T maxInput = input[inputStartIndex];
+ for (int i = 1; i < classSize; i++) {
+ T value = input[inputStartIndex + i];
+ maxInput = THCNumerics<T>::ge(maxInput, value) ? maxInput : value;
+ }
+
AccumT sum = 0;
for (int i = 0; i < classSize; i++) {
- sum += THCNumerics<T>::exp(input[inputStartIndex + i]);
+ sum += THCNumerics<T>::exp(input[inputStartIndex + i] - maxInput);
}
- sum = AccumT(1) / sum;
+ T logsum = maxInput + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum));
for (int i = 0; i < classSize; i++) {
// calculate output index in torch layout (B x C x H x W)
@@ -34,8 +40,7 @@ __global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input,
(height*width)*i +
(width)*y +
x;
- output[outputIndex] = ScalarConvert<AccumT, T>::to(
- THCNumerics<AccumT>::log(sum * THCNumerics<T>::exp(input[inputStartIndex + i])));
+ output[outputIndex] = input[inputStartIndex + i] - logsum;
}
index += blockDim.x;
}