Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-11-28 22:16:05 +0300
committerGregory Chanan <gchanan@fb.com>2016-11-28 22:19:00 +0300
commit0386f1354d744c98b7bbdca8dbe8627640bf2899 (patch)
tree8667ab18d87f4afeea17918cdb8c87f8a534ac1f
parent8161422ad0393f4653f996db127a2aa9fe9a7e5d (diff)
Improve error messages/shape checks for temporal modules.
-rw-r--r--lib/THNN/generic/TemporalMaxPooling.c23
-rw-r--r--lib/THNN/generic/TemporalSubSampling.c31
2 files changed, 48 insertions, 6 deletions
diff --git a/lib/THNN/generic/TemporalMaxPooling.c b/lib/THNN/generic/TemporalMaxPooling.c
index 0a2f004..462c170 100644
--- a/lib/THNN/generic/TemporalMaxPooling.c
+++ b/lib/THNN/generic/TemporalMaxPooling.c
@@ -2,6 +2,25 @@
#define TH_GENERIC_FILE "generic/TemporalMaxPooling.c"
#else
+static inline void THNN_(TemporalMaxPooling_shapeCheck)(
+ THNNState *state,
+ THTensor *input,
+ int kW,
+ int dW) {
+ THArgCheck(kW > 0, 5,
+ "kernel size should be greater than zero, but got kW: %d", kW);
+ THArgCheck(dW > 0, 6,
+ "stride should be greater than zero, but got dW: %d", dW);
+
+ THNN_ARGCHECK(input->nDimension == 2 || input->nDimension == 3, 2, input,
+ "2D or 3D (batch mode) tensor expected for input, but got: %s");
+
+ int dimS = input->nDimension == 3 ? 1 : 0;
+ THArgCheck(input->size[dimS] >= kW, 2,
+ "input sequence smaller than kernel size. Got: %d, Expected: %d",
+ input->size[dimS], kW);
+}
+
void THNN_(TemporalMaxPooling_updateOutput)(
THNNState *state,
THTensor *input,
@@ -23,14 +42,13 @@ void THNN_(TemporalMaxPooling_updateOutput)(
int dimS = 0; // sequence dimension
int dimF = 1; // feature dimension
- THArgCheck(input->nDimension == 2 || input->nDimension == 3, 2, "2D or 3D(batch mode) tensor expected");
+ THNN_(TemporalMaxPooling_shapeCheck)(state, input, kW, dW);
if (input->nDimension == 3)
{
dimS = 1;
dimF = 2;
}
- THArgCheck(input->size[dimS] >= kW, 2, "input sequence smaller than kernel size");
/* sizes */
niframe = input->size[dimS];
@@ -159,6 +177,7 @@ void THNN_(TemporalMaxPooling_updateGradInput)(
long t, y;
+ THNN_(TemporalMaxPooling_shapeCheck)(state, input, kW, dW);
/* get contiguous gradOutput */
gradOutput = THTensor_(newContiguous)(gradOutput);
diff --git a/lib/THNN/generic/TemporalSubSampling.c b/lib/THNN/generic/TemporalSubSampling.c
index 7fa323d..e53cc59 100644
--- a/lib/THNN/generic/TemporalSubSampling.c
+++ b/lib/THNN/generic/TemporalSubSampling.c
@@ -2,6 +2,29 @@
#define TH_GENERIC_FILE "generic/TemporalSubSampling.c"
#else
+static inline void THNN_(TemporalSubSampling_shapeCheck)(
+ THNNState *state,
+ THTensor *input,
+ int kW,
+ int dW,
+ int *inputFrameSize) {
+ THArgCheck(kW > 0, 6,
+ "kernel size should be greater than zero, but got kW: %d", kW);
+ THArgCheck(dW > 0, 7,
+ "stride should be greater than zero, but got dW: %d", dW);
+
+ THNN_ARGCHECK(input->nDimension == 2, 2, input,
+ "2D or 3D (batch mode) tensor expected for input, but got: %s");
+ if (inputFrameSize != NULL) {
+ THArgCheck( input->size[1] == *inputFrameSize, 2,
+ "invalid input frame size. Got: %d, Expected: %d",
+ input->size[1], *inputFrameSize);
+ }
+ THArgCheck( input->size[0] >= kW, 2,
+ "input sequence smaller than kernel size. Got %d, Expected: %d",
+ input->size[0], kW);
+}
+
void THNN_(TemporalSubSampling_updateOutput)(
THNNState *state,
THTensor *input,
@@ -16,9 +39,7 @@ void THNN_(TemporalSubSampling_updateOutput)(
int nInputFrame, nOutputFrame;
long k;
- THArgCheck( input->nDimension == 2, 2, "2D tensor expected");
- THArgCheck( input->size[1] == inputFrameSize, 2, "invalid input frame size");
- THArgCheck( input->size[0] >= kW, 2, "input sequence smaller than kernel size");
+ THNN_(TemporalSubSampling_shapeCheck)(state, input, kW, dW, &inputFrameSize);
outputFrame = THTensor_(new)();
inputWindow = THTensor_(new)();
@@ -57,6 +78,8 @@ void THNN_(TemporalSubSampling_updateGradInput)(
THTensor *gradInputWindow, *buffer, *kwunit;
long k;
+ THNN_(TemporalSubSampling_shapeCheck)(state, input, kW, dW, NULL);
+
gradOutputFrame = THTensor_(new)();
gradInputWindow = THTensor_(new)();
buffer = THTensor_(new)();
@@ -94,7 +117,7 @@ void THNN_(TemporalSubSampling_accGradParameters)(
THTensor *inputWindow, *buffer;
long k;
-
+ THNN_(TemporalSubSampling_shapeCheck)(state, input, kW, dW, NULL);
gradOutputFrame = THTensor_(new)();
inputWindow = THTensor_(new)();
buffer = THTensor_(new)();