diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-11-29 23:19:56 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-12-02 02:49:48 +0300 |
commit | 0d223082f1bfc62e7cc7df58e1a124568cfabbed (patch) | |
tree | b02d3aed3c4f2de1940eb7af5d2ceec9dd2bf10d | |
parent | bb8e4628fb02bb0b6ae5c5998ef07498322b4e35 (diff) |
Add gradOutput shape checks in temporal modules.
-rw-r--r-- | lib/THCUNN/generic/TemporalMaxPooling.cu | 33 |
1 files changed, 29 insertions, 4 deletions
diff --git a/lib/THCUNN/generic/TemporalMaxPooling.cu b/lib/THCUNN/generic/TemporalMaxPooling.cu index ccc3900..41db31c 100644 --- a/lib/THCUNN/generic/TemporalMaxPooling.cu +++ b/lib/THCUNN/generic/TemporalMaxPooling.cu @@ -5,7 +5,21 @@ static inline void THNN_(TemporalMaxPooling_shapeCheck)( THCState *state, THCTensor *input, + THCTensor *gradOutput, + THCIndexTensor *indices, int kW, int dW) { + int dimT = 0; // Temporal dimension + int dimF = 1; // Feature dimension + int input_w; + int input_n; + int output_w; + int ndims = input->nDimension; + + if (ndims == 3) + { + dimT = 1; + dimF = 2; + } THArgCheck(kW > 0, 5, "kernel size should be greater than zero, but got kW: %d", kW); THArgCheck(dW > 0, 6, @@ -13,11 +27,22 @@ static inline void THNN_(TemporalMaxPooling_shapeCheck)( THCUNN_argCheck(state, input->nDimension == 2 || input->nDimension == 3, 2, input, "2D or 3D (batch mode) tensor expected for input, but got: %s"); - - int dimT = input->nDimension == 3 ? 1 : 0; THArgCheck(input->size[dimT] >= kW, 2, "input sequence smaller than kernel size. Got: %d, Expected: %d", input->size[dimT], kW); + + input_w = input->size[dimT]; + input_n = input->size[dimF]; + output_w = (input_w - kW) / dW + 1; + + if (gradOutput != NULL) { + THCUNN_check_dim_size(state, gradOutput, ndims, dimT, output_w); + THCUNN_check_dim_size(state, gradOutput, ndims, dimF, input_n) + } + if (indices != NULL) { + THCUNN_check_dim_size_indices(state, indices, ndims, dimT, output_w); + THCUNN_check_dim_size_indices(state, indices, ndims, dimF, input_n); + } } void THNN_(TemporalMaxPooling_updateOutput)( @@ -41,7 +66,7 @@ void THNN_(TemporalMaxPooling_updateOutput)( THCIndex_t *indices_data; THCUNN_assertSameGPU(state, 3, input, output, indices); - THNN_(TemporalMaxPooling_shapeCheck)(state, input, kW, dW); + THNN_(TemporalMaxPooling_shapeCheck)(state, input, NULL, NULL, kW, dW); if (input->nDimension == 3) { dimT = 1; @@ -113,7 +138,7 @@ void THNN_(TemporalMaxPooling_updateGradInput)( THCIndex_t *indices_data; THCUNN_assertSameGPU(state, 4, input, gradOutput, gradInput, indices); - THNN_(TemporalMaxPooling_shapeCheck)(state, input, kW, dW); + THNN_(TemporalMaxPooling_shapeCheck)(state, input, gradOutput, indices, kW, dW); THCTensor_(resizeAs)(state, gradInput, input); THCTensor_(zero)(state, gradInput); |