Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-12-05 20:26:28 +0300
committerGregory Chanan <gchanan@fb.com>2016-12-06 20:09:38 +0300
commitf6653730509349ee4760ef112fb2ae80fe10325e (patch)
treeae5729a767a69fb42f69709733d3b78bbf541963
parent821e0e7cb5bf43cdeb4b284f8d7b770f7f8ec16d (diff)
Improve gradOutput checks for VolumetricReplicationPadding.
-rw-r--r--lib/THNN/generic/VolumetricReplicationPadding.c7
1 files changed, 7 insertions, 0 deletions
diff --git a/lib/THNN/generic/VolumetricReplicationPadding.c b/lib/THNN/generic/VolumetricReplicationPadding.c
index 8d9dda3..4d8993e 100644
--- a/lib/THNN/generic/VolumetricReplicationPadding.c
+++ b/lib/THNN/generic/VolumetricReplicationPadding.c
@@ -12,6 +12,8 @@ static inline void THNN_(VolumetricReplicationPadding_shapeCheck)(
int dimw = 3;
int dimh = 2;
int dimd = 1;
+ int dimslices = 0;
+ long nslices;
long idepth;
long iheight;
long iwidth;
@@ -27,9 +29,11 @@ static inline void THNN_(VolumetricReplicationPadding_shapeCheck)(
dimw++;
dimh++;
dimd++;
+ dimslices++;
}
/* sizes */
+ nslices = input->size[dimslices];
idepth = input->size[dimd];
iheight = input->size[dimh];
iwidth = input->size[dimw];
@@ -43,6 +47,9 @@ static inline void THNN_(VolumetricReplicationPadding_shapeCheck)(
idepth, iheight, iwidth, odepth, oheight, owidth);
if (gradOutput != NULL) {
+ THArgCheck(nslices == THTensor_(size)(gradOutput, dimslices), 3,
+ "gradOutput width unexpected. Expected: %d, Got: %d",
+ nslices, THTensor_(size)(gradOutput, dimslices));
THArgCheck(owidth == THTensor_(size)(gradOutput, dimw), 3,
"gradOutput width unexpected. Expected: %d, Got: %d",
owidth, THTensor_(size)(gradOutput, dimw));