Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-11-29 23:19:56 +0300
committerGregory Chanan <gchanan@fb.com>2016-12-02 02:49:48 +0300
commit0d223082f1bfc62e7cc7df58e1a124568cfabbed (patch)
treeb02d3aed3c4f2de1940eb7af5d2ceec9dd2bf10d
parentbb8e4628fb02bb0b6ae5c5998ef07498322b4e35 (diff)
Add gradOutput shape checks in temporal modules.
-rw-r--r--lib/THCUNN/generic/TemporalMaxPooling.cu33
1 files changed, 29 insertions, 4 deletions
diff --git a/lib/THCUNN/generic/TemporalMaxPooling.cu b/lib/THCUNN/generic/TemporalMaxPooling.cu
index ccc3900..41db31c 100644
--- a/lib/THCUNN/generic/TemporalMaxPooling.cu
+++ b/lib/THCUNN/generic/TemporalMaxPooling.cu
@@ -5,7 +5,21 @@
static inline void THNN_(TemporalMaxPooling_shapeCheck)(
THCState *state,
THCTensor *input,
+ THCTensor *gradOutput,
+ THCIndexTensor *indices,
int kW, int dW) {
+ int dimT = 0; // Temporal dimension
+ int dimF = 1; // Feature dimension
+ int input_w;
+ int input_n;
+ int output_w;
+ int ndims = input->nDimension;
+
+ if (ndims == 3)
+ {
+ dimT = 1;
+ dimF = 2;
+ }
THArgCheck(kW > 0, 5,
"kernel size should be greater than zero, but got kW: %d", kW);
THArgCheck(dW > 0, 6,
@@ -13,11 +27,22 @@ static inline void THNN_(TemporalMaxPooling_shapeCheck)(
THCUNN_argCheck(state, input->nDimension == 2 || input->nDimension == 3, 2, input,
"2D or 3D (batch mode) tensor expected for input, but got: %s");
-
- int dimT = input->nDimension == 3 ? 1 : 0;
THArgCheck(input->size[dimT] >= kW, 2,
"input sequence smaller than kernel size. Got: %d, Expected: %d",
input->size[dimT], kW);
+
+ input_w = input->size[dimT];
+ input_n = input->size[dimF];
+ output_w = (input_w - kW) / dW + 1;
+
+ if (gradOutput != NULL) {
+ THCUNN_check_dim_size(state, gradOutput, ndims, dimT, output_w);
+ THCUNN_check_dim_size(state, gradOutput, ndims, dimF, input_n)
+ }
+ if (indices != NULL) {
+ THCUNN_check_dim_size_indices(state, indices, ndims, dimT, output_w);
+ THCUNN_check_dim_size_indices(state, indices, ndims, dimF, input_n);
+ }
}
void THNN_(TemporalMaxPooling_updateOutput)(
@@ -41,7 +66,7 @@ void THNN_(TemporalMaxPooling_updateOutput)(
THCIndex_t *indices_data;
THCUNN_assertSameGPU(state, 3, input, output, indices);
- THNN_(TemporalMaxPooling_shapeCheck)(state, input, kW, dW);
+ THNN_(TemporalMaxPooling_shapeCheck)(state, input, NULL, NULL, kW, dW);
if (input->nDimension == 3)
{
dimT = 1;
@@ -113,7 +138,7 @@ void THNN_(TemporalMaxPooling_updateGradInput)(
THCIndex_t *indices_data;
THCUNN_assertSameGPU(state, 4, input, gradOutput, gradInput, indices);
- THNN_(TemporalMaxPooling_shapeCheck)(state, input, kW, dW);
+ THNN_(TemporalMaxPooling_shapeCheck)(state, input, gradOutput, indices, kW, dW);
THCTensor_(resizeAs)(state, gradInput, input);
THCTensor_(zero)(state, gradInput);