diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-11-28 22:46:26 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-11-28 22:46:26 +0300 |
commit | fdb589f3598ec48d74c53d1b4c6e561dbe91dc48 (patch) | |
tree | fa533877902fffa34b59d90bcbdab7dd2f534a88 | |
parent | cea3855c6b9d4e39bd8a681b5d74f07802cf2e25 (diff) |
Improve error messages/shape check in TemporalMaxPooling.
-rw-r--r-- | lib/THCUNN/generic/TemporalMaxPooling.cu | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/lib/THCUNN/generic/TemporalMaxPooling.cu b/lib/THCUNN/generic/TemporalMaxPooling.cu index d807abc..ccc3900 100644 --- a/lib/THCUNN/generic/TemporalMaxPooling.cu +++ b/lib/THCUNN/generic/TemporalMaxPooling.cu @@ -2,6 +2,24 @@ #define THC_GENERIC_FILE "generic/TemporalMaxPooling.cu" #else +static inline void THNN_(TemporalMaxPooling_shapeCheck)( + THCState *state, + THCTensor *input, + int kW, int dW) { + THArgCheck(kW > 0, 5, + "kernel size should be greater than zero, but got kW: %d", kW); + THArgCheck(dW > 0, 6, + "stride should be greater than zero, but got dW: %d", dW); + + THCUNN_argCheck(state, input->nDimension == 2 || input->nDimension == 3, 2, input, + "2D or 3D (batch mode) tensor expected for input, but got: %s"); + + int dimT = input->nDimension == 3 ? 1 : 0; + THArgCheck(input->size[dimT] >= kW, 2, + "input sequence smaller than kernel size. Got: %d, Expected: %d", + input->size[dimT], kW); +} + void THNN_(TemporalMaxPooling_updateOutput)( THCState *state, THCTensor *input, @@ -23,16 +41,13 @@ void THNN_(TemporalMaxPooling_updateOutput)( THCIndex_t *indices_data; THCUNN_assertSameGPU(state, 3, input, output, indices); - THArgCheck( input->nDimension == 2 || input->nDimension == 3, 2, "2D or 3D(batch mode) tensor expected"); - + THNN_(TemporalMaxPooling_shapeCheck)(state, input, kW, dW); if (input->nDimension == 3) { dimT = 1; dimF = 2; batch = input->size[0]; } - THArgCheck( input->size[dimT] >= kW, 2, "input sequence smaller than kernel size"); - input = THCTensor_(newContiguous)(state, input); input_w = input->size[dimT]; @@ -98,8 +113,7 @@ void THNN_(TemporalMaxPooling_updateGradInput)( THCIndex_t *indices_data; THCUNN_assertSameGPU(state, 4, input, gradOutput, gradInput, indices); - THArgCheck( input->nDimension == 2 || input->nDimension == 3, 2, "2D or 3D(batch mode) tensor expected"); - + THNN_(TemporalMaxPooling_shapeCheck)(state, input, kW, dW); THCTensor_(resizeAs)(state, gradInput, input); THCTensor_(zero)(state, gradInput); @@ -109,8 +123,6 @@ void THNN_(TemporalMaxPooling_updateGradInput)( dimF = 2; batch = input->size[0]; } - THArgCheck( input->size[dimT] >= kW, 2, "input sequence smaller than kernel size"); - gradOutput = THCTensor_(newContiguous)(state, gradOutput); input_w = input->size[dimT]; |