diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-02-18 06:10:32 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-02-18 06:10:32 +0300 |
commit | 0665363fe8ba6369f280ce4f7baca594b21f4ea8 (patch) | |
tree | ce26dd1147a6acf96d2c8db35f7418cd64e01370 | |
parent | 618f847d94ad65baef1c1614ed241d6e4bea7151 (diff) | |
parent | 51cc7c5c2e6743e96932dc16b3959d55f8619b70 (diff) |
Merge pull request #444 from apaszke/fixes
Improvements for spatial functions
-rw-r--r-- | lib/THCUNN/LogSoftMax.cu | 13 | ||||
-rw-r--r-- | lib/THCUNN/generic/SpatialClassNLLCriterion.cu | 35 |
2 files changed, 32 insertions, 16 deletions
diff --git a/lib/THCUNN/LogSoftMax.cu b/lib/THCUNN/LogSoftMax.cu index 4d7973e..98b7670 100644 --- a/lib/THCUNN/LogSoftMax.cu +++ b/lib/THCUNN/LogSoftMax.cu @@ -21,11 +21,17 @@ __global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input, (width*classSize)*y + (classSize)*x; + T maxInput = input[inputStartIndex]; + for (int i = 1; i < classSize; i++) { + T value = input[inputStartIndex + i]; + maxInput = THCNumerics<T>::ge(maxInput, value) ? maxInput : value; + } + AccumT sum = 0; for (int i = 0; i < classSize; i++) { - sum += THCNumerics<T>::exp(input[inputStartIndex + i]); + sum += THCNumerics<T>::exp(input[inputStartIndex + i] - maxInput); } - sum = AccumT(1) / sum; + T logsum = maxInput + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum)); for (int i = 0; i < classSize; i++) { // calculate output index in torch layout (B x C x H x W) @@ -34,8 +40,7 @@ __global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input, (height*width)*i + (width)*y + x; - output[outputIndex] = ScalarConvert<AccumT, T>::to( - THCNumerics<AccumT>::log(sum * THCNumerics<T>::exp(input[inputStartIndex + i]))); + output[outputIndex] = input[inputStartIndex + i] - logsum; } index += blockDim.x; } diff --git a/lib/THCUNN/generic/SpatialClassNLLCriterion.cu b/lib/THCUNN/generic/SpatialClassNLLCriterion.cu index d9ffc86..6bf1783 100644 --- a/lib/THCUNN/generic/SpatialClassNLLCriterion.cu +++ b/lib/THCUNN/generic/SpatialClassNLLCriterion.cu @@ -2,14 +2,11 @@ #define THC_GENERIC_FILE "generic/SpatialClassNLLCriterion.cu" #else -void THNN_(SpatialClassNLLCriterion_updateOutput)( +void THNN_(SpatialClassNLLCriterion_shapeCheck)( THCState *state, THCTensor *input, THCIndexTensor *target, - THCTensor *output, - bool sizeAverage, - THCTensor *weights, - THCTensor *total_weight) + THCTensor *weights) { THArgCheck(THCIndexTensor_(nDimension)(state, target) == 3, 1, "only batches of spatial targets supported (3D tensors)" \ @@ -18,10 +15,30 @@ void THNN_(SpatialClassNLLCriterion_updateOutput)( THArgCheck(THCTensor_(nDimension)(state, input) == 4, 2, "only batches of spatial inputs supported (4D tensors), " \ "but got input of dimension: %d", THCTensor_(nDimension)(state, input)); + if (THCTensor_(size)(state, input, 0) != THCIndexTensor_(size)(state, target, 0) || + THCTensor_(size)(state, input, 2) != THCIndexTensor_(size)(state, target, 1) || + THCTensor_(size)(state, input, 3) != THCIndexTensor_(size)(state, target, 2)) { + THCDescBuff input_size = THCTensor_(sizeDesc)(state, input); + THCDescBuff target_size = THCIndexTensor_(sizeDesc)(state, target); + THError("input and target batch or spatial sizes don't match: target %s, input %s", + target_size.str, input_size.str); + } if (weights && THCTensor_(nElement)(state, weights) != THCTensor_(size)(state, input, 1)) { THError("weight tensor should be defined either for all or no classes"); } +} + +void THNN_(SpatialClassNLLCriterion_updateOutput)( + THCState *state, + THCTensor *input, + THCIndexTensor *target, + THCTensor *output, + bool sizeAverage, + THCTensor *weights, + THCTensor *total_weight) +{ + THNN_(SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights); if (weights) THCUNN_assertSameGPU(state, 5, input, target, weights, output, total_weight); @@ -77,15 +94,9 @@ void THNN_(SpatialClassNLLCriterion_updateGradInput)( THCTensor *weights, THCTensor *total_weight) { - THArgCheck(THCIndexTensor_(nDimension)(state, target) == 3, 1, - "only batches of spatial targets supported (3D tensors)"); - THArgCheck(THCTensor_(nDimension)(state, input) == 4, 2, - "only batches of spatial inputs supported (4D tensors)"); + THNN_(SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights); THArgCheck(THCTensor_(isContiguous)(state, gradInput), 4, "gradInput must be contiguous"); - if (weights && THCTensor_(nElement)(state, weights) != THCTensor_(size)(state, input, 1)) { - THError("weight tensor should be defined either for all or no classes"); - } if (weights) THCUNN_assertSameGPU(state, 5, weights, input, target, gradInput, total_weight); |