Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Soumith Chintala <soumith@gmail.com>, 2017-02-18 06:10:32 +0300
committer: GitHub <noreply@github.com>, 2017-02-18 06:10:32 +0300
commit: 0665363fe8ba6369f280ce4f7baca594b21f4ea8 (patch)
tree: ce26dd1147a6acf96d2c8db35f7418cd64e01370
parent: 618f847d94ad65baef1c1614ed241d6e4bea7151 (diff)
parent: 51cc7c5c2e6743e96932dc16b3959d55f8619b70 (diff)
Merge pull request #444 from apaszke/fixes
Improvements for spatial functions
-rw-r--r-- lib/THCUNN/LogSoftMax.cu 13
-rw-r--r-- lib/THCUNN/generic/SpatialClassNLLCriterion.cu 35
2 files changed, 32 insertions, 16 deletions
diff --git a/lib/THCUNN/LogSoftMax.cu b/lib/THCUNN/LogSoftMax.cu
index 4d7973e..98b7670 100644
--- a/lib/THCUNN/LogSoftMax.cu
+++ b/lib/THCUNN/LogSoftMax.cu
@@ -21,11 +21,17 @@ __global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input,
(width*classSize)*y +
(classSize)*x;
+ T maxInput = input[inputStartIndex];
+ for (int i = 1; i < classSize; i++) {
+ T value = input[inputStartIndex + i];
+ maxInput = THCNumerics<T>::ge(maxInput, value) ? maxInput : value;
+ }
+
AccumT sum = 0;
for (int i = 0; i < classSize; i++) {
- sum += THCNumerics<T>::exp(input[inputStartIndex + i]);
+ sum += THCNumerics<T>::exp(input[inputStartIndex + i] - maxInput);
}
- sum = AccumT(1) / sum;
+ T logsum = maxInput + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum));
for (int i = 0; i < classSize; i++) {
// calculate output index in torch layout (B x C x H x W)
@@ -34,8 +40,7 @@ __global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input,
(height*width)*i +
(width)*y +
x;
- output[outputIndex] = ScalarConvert<AccumT, T>::to(
- THCNumerics<AccumT>::log(sum * THCNumerics<T>::exp(input[inputStartIndex + i])));
+ output[outputIndex] = input[inputStartIndex + i] - logsum;
}
index += blockDim.x;
}
diff --git a/lib/THCUNN/generic/SpatialClassNLLCriterion.cu b/lib/THCUNN/generic/SpatialClassNLLCriterion.cu
index d9ffc86..6bf1783 100644
--- a/lib/THCUNN/generic/SpatialClassNLLCriterion.cu
+++ b/lib/THCUNN/generic/SpatialClassNLLCriterion.cu
@@ -2,14 +2,11 @@
#define THC_GENERIC_FILE "generic/SpatialClassNLLCriterion.cu"
#else
-void THNN_(SpatialClassNLLCriterion_updateOutput)(
+void THNN_(SpatialClassNLLCriterion_shapeCheck)(
THCState *state,
THCTensor *input,
THCIndexTensor *target,
- THCTensor *output,
- bool sizeAverage,
- THCTensor *weights,
- THCTensor *total_weight)
+ THCTensor *weights)
{
THArgCheck(THCIndexTensor_(nDimension)(state, target) == 3, 1,
"only batches of spatial targets supported (3D tensors)" \
@@ -18,10 +15,30 @@ void THNN_(SpatialClassNLLCriterion_updateOutput)(
THArgCheck(THCTensor_(nDimension)(state, input) == 4, 2,
"only batches of spatial inputs supported (4D tensors), " \
"but got input of dimension: %d", THCTensor_(nDimension)(state, input));
+ if (THCTensor_(size)(state, input, 0) != THCIndexTensor_(size)(state, target, 0) ||
+ THCTensor_(size)(state, input, 2) != THCIndexTensor_(size)(state, target, 1) ||
+ THCTensor_(size)(state, input, 3) != THCIndexTensor_(size)(state, target, 2)) {
+ THCDescBuff input_size = THCTensor_(sizeDesc)(state, input);
+ THCDescBuff target_size = THCIndexTensor_(sizeDesc)(state, target);
+ THError("input and target batch or spatial sizes don't match: target %s, input %s",
+ target_size.str, input_size.str);
+ }
if (weights && THCTensor_(nElement)(state, weights) != THCTensor_(size)(state, input, 1)) {
THError("weight tensor should be defined either for all or no classes");
}
+}
+
+void THNN_(SpatialClassNLLCriterion_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCIndexTensor *target,
+ THCTensor *output,
+ bool sizeAverage,
+ THCTensor *weights,
+ THCTensor *total_weight)
+{
+ THNN_(SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights);
if (weights)
THCUNN_assertSameGPU(state, 5, input, target, weights, output, total_weight);
@@ -77,15 +94,9 @@ void THNN_(SpatialClassNLLCriterion_updateGradInput)(
THCTensor *weights,
THCTensor *total_weight)
{
- THArgCheck(THCIndexTensor_(nDimension)(state, target) == 3, 1,
- "only batches of spatial targets supported (3D tensors)");
- THArgCheck(THCTensor_(nDimension)(state, input) == 4, 2,
- "only batches of spatial inputs supported (4D tensors)");
+ THNN_(SpatialClassNLLCriterion_shapeCheck)(state, input, target, weights);
THArgCheck(THCTensor_(isContiguous)(state, gradInput), 4,
"gradInput must be contiguous");
- if (weights && THCTensor_(nElement)(state, weights) != THCTensor_(size)(state, input, 1)) {
- THError("weight tensor should be defined either for all or no classes");
- }
if (weights)
THCUNN_assertSameGPU(state, 5, weights, input, target, gradInput, total_weight);