diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-11-28 22:16:05 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-11-28 22:19:00 +0300 |
commit | 0386f1354d744c98b7bbdca8dbe8627640bf2899 (patch) | |
tree | 8667ab18d87f4afeea17918cdb8c87f8a534ac1f | |
parent | 8161422ad0393f4653f996db127a2aa9fe9a7e5d (diff) |
Improve error messages/shape checks for temporal modules.
-rw-r--r-- | lib/THNN/generic/TemporalMaxPooling.c | 23 | ||||
-rw-r--r-- | lib/THNN/generic/TemporalSubSampling.c | 31 |
2 files changed, 48 insertions, 6 deletions
diff --git a/lib/THNN/generic/TemporalMaxPooling.c b/lib/THNN/generic/TemporalMaxPooling.c index 0a2f004..462c170 100644 --- a/lib/THNN/generic/TemporalMaxPooling.c +++ b/lib/THNN/generic/TemporalMaxPooling.c @@ -2,6 +2,25 @@ #define TH_GENERIC_FILE "generic/TemporalMaxPooling.c" #else +static inline void THNN_(TemporalMaxPooling_shapeCheck)( + THNNState *state, + THTensor *input, + int kW, + int dW) { + THArgCheck(kW > 0, 5, + "kernel size should be greater than zero, but got kW: %d", kW); + THArgCheck(dW > 0, 6, + "stride should be greater than zero, but got dW: %d", dW); + + THNN_ARGCHECK(input->nDimension == 2 || input->nDimension == 3, 2, input, + "2D or 3D (batch mode) tensor expected for input, but got: %s"); + + int dimS = input->nDimension == 3 ? 1 : 0; + THArgCheck(input->size[dimS] >= kW, 2, + "input sequence smaller than kernel size. Got: %d, Expected: %d", + input->size[dimS], kW); +} + void THNN_(TemporalMaxPooling_updateOutput)( THNNState *state, THTensor *input, @@ -23,14 +42,13 @@ void THNN_(TemporalMaxPooling_updateOutput)( int dimS = 0; // sequence dimension int dimF = 1; // feature dimension - THArgCheck(input->nDimension == 2 || input->nDimension == 3, 2, "2D or 3D(batch mode) tensor expected"); + THNN_(TemporalMaxPooling_shapeCheck)(state, input, kW, dW); if (input->nDimension == 3) { dimS = 1; dimF = 2; } - THArgCheck(input->size[dimS] >= kW, 2, "input sequence smaller than kernel size"); /* sizes */ niframe = input->size[dimS]; @@ -159,6 +177,7 @@ void THNN_(TemporalMaxPooling_updateGradInput)( long t, y; + THNN_(TemporalMaxPooling_shapeCheck)(state, input, kW, dW); /* get contiguous gradOutput */ gradOutput = THTensor_(newContiguous)(gradOutput); diff --git a/lib/THNN/generic/TemporalSubSampling.c b/lib/THNN/generic/TemporalSubSampling.c index 7fa323d..e53cc59 100644 --- a/lib/THNN/generic/TemporalSubSampling.c +++ b/lib/THNN/generic/TemporalSubSampling.c @@ -2,6 +2,29 @@ #define TH_GENERIC_FILE "generic/TemporalSubSampling.c" #else +static inline void THNN_(TemporalSubSampling_shapeCheck)( + THNNState *state, + THTensor *input, + int kW, + int dW, + int *inputFrameSize) { + THArgCheck(kW > 0, 6, + "kernel size should be greater than zero, but got kW: %d", kW); + THArgCheck(dW > 0, 7, + "stride should be greater than zero, but got dW: %d", dW); + + THNN_ARGCHECK(input->nDimension == 2, 2, input, + "2D or 3D (batch mode) tensor expected for input, but got: %s"); + if (inputFrameSize != NULL) { + THArgCheck( input->size[1] == *inputFrameSize, 2, + "invalid input frame size. Got: %d, Expected: %d", + input->size[1], *inputFrameSize); + } + THArgCheck( input->size[0] >= kW, 2, + "input sequence smaller than kernel size. Got %d, Expected: %d", + input->size[0], kW); +} + void THNN_(TemporalSubSampling_updateOutput)( THNNState *state, THTensor *input, @@ -16,9 +39,7 @@ void THNN_(TemporalSubSampling_updateOutput)( int nInputFrame, nOutputFrame; long k; - THArgCheck( input->nDimension == 2, 2, "2D tensor expected"); - THArgCheck( input->size[1] == inputFrameSize, 2, "invalid input frame size"); - THArgCheck( input->size[0] >= kW, 2, "input sequence smaller than kernel size"); + THNN_(TemporalSubSampling_shapeCheck)(state, input, kW, dW, &inputFrameSize); outputFrame = THTensor_(new)(); inputWindow = THTensor_(new)(); @@ -57,6 +78,8 @@ void THNN_(TemporalSubSampling_updateGradInput)( THTensor *gradInputWindow, *buffer, *kwunit; long k; + THNN_(TemporalSubSampling_shapeCheck)(state, input, kW, dW, NULL); + gradOutputFrame = THTensor_(new)(); gradInputWindow = THTensor_(new)(); buffer = THTensor_(new)(); @@ -94,7 +117,7 @@ void THNN_(TemporalSubSampling_accGradParameters)( THTensor *inputWindow, *buffer; long k; - + THNN_(TemporalSubSampling_shapeCheck)(state, input, kW, dW, NULL); gradOutputFrame = THTensor_(new)(); inputWindow = THTensor_(new)(); buffer = THTensor_(new)(); |