Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-11-15 20:11:39 +0300
committerGregory Chanan <gchanan@fb.com>2016-11-16 01:16:48 +0300
commit873211115de9fc29826c21e4bbfcf4584233798a (patch)
tree1669a340f7976f00a2dcaf1a5d0cb11fbe473106
parent2c571251ebff059771d5f99548cb0797ae0c56f4 (diff)
SpatialSubSampling contiguous check.
-rw-r--r--lib/THCUNN/generic/SpatialSubSampling.cu41
1 files changed, 34 insertions, 7 deletions
diff --git a/lib/THCUNN/generic/SpatialSubSampling.cu b/lib/THCUNN/generic/SpatialSubSampling.cu
index edcea6f..ed07415 100644
--- a/lib/THCUNN/generic/SpatialSubSampling.cu
+++ b/lib/THCUNN/generic/SpatialSubSampling.cu
@@ -4,6 +4,37 @@
#include "../common.h"
+static inline void THNN_(SpatialSubSampling_shapeCheck)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *weight,
+ int kW, int kH) {
+ THCUNN_argCheck(state, input->nDimension == 3 || input->nDimension == 4, 2, input,
+ "3D or 4D input tensor expected but got: %s");
+
+ int nInputPlane = THCTensor_(size)(state, weight, 0);
+
+ int dimc = 2;
+ int dimr = 1;
+ int dimp = 0;
+
+ if (input->nDimension == 4) {
+ dimc++;
+ dimr++;
+ dimp++;
+ }
+
+ long nInputCols = input->size[dimc];
+ long nInputRows = input->size[dimr];
+ THArgCheck(input->size[dimp] == nInputPlane, 2, "invalid number of input planes");
+ THArgCheck(nInputCols >= kW && nInputRows >= kH, 2, "input image smaller than kernel size");
+
+ if (gradOutput != NULL) {
+ THArgCheck(THCTensor_(isContiguous)(state, gradOutput), 3, "gradOutput must be contiguous");
+ }
+}
+
void THNN_(SpatialSubSampling_updateOutput)(
THCState *state,
THCTensor *input,
@@ -21,7 +52,7 @@ void THNN_(SpatialSubSampling_updateOutput)(
int nInputPlane = THCTensor_(size)(state, weight, 0);
THCUNN_assertSameGPU(state, 4, input, output, weight, bias);
- THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch) tensor expected");
+ THNN_(SpatialSubSampling_shapeCheck)(state, input, NULL, weight, kW, kH);
if (input->nDimension == 3) {
long nInputCols = input->size[2];
@@ -29,9 +60,6 @@ void THNN_(SpatialSubSampling_updateOutput)(
long nOutputCols = (nInputCols - kW) / dW + 1;
long nOutputRows = (nInputRows - kH) / dH + 1;
- THArgCheck(input->size[0] == nInputPlane, 2, "invalid number of input planes");
- THArgCheck(nInputCols >= kW && nInputRows >= kH, 2, "input image smaller than kernel size");
-
input = THCTensor_(newContiguous)(state, input);
input_data = THCTensor_(data)(state, input);
@@ -56,9 +84,6 @@ void THNN_(SpatialSubSampling_updateOutput)(
long nOutputCols = (nInputCols - kW) / dW + 1;
long nOutputRows = (nInputRows - kH) / dH + 1;
- THArgCheck(input->size[1] == nInputPlane, 2, "invalid number of input planes");
- THArgCheck(nInputCols >= kW && nInputRows >= kH, 2, "input image smaller than kernel size");
-
input = THCTensor_(newContiguous)(state, input);
input_data = THCTensor_(data)(state, input);
@@ -93,6 +118,7 @@ void THNN_(SpatialSubSampling_updateGradInput)(
int dW, int dH)
{
THCUNN_assertSameGPU(state, 4, input, gradOutput, weight, gradInput);
+ THNN_(SpatialSubSampling_shapeCheck)(state, input, gradOutput, weight, kW, kH);
int nInputPlane = THCTensor_(size)(state, weight, 0);
@@ -169,6 +195,7 @@ void THNN_(SpatialSubSampling_accGradParameters)(
float scale)
{
THCUNN_assertSameGPU(state, 4, input, gradOutput, gradWeight, gradBias);
+ THNN_(SpatialSubSampling_shapeCheck)(state, input, gradOutput, gradWeight, kW, kH);
int nInputPlane = THCTensor_(size)(state, gradWeight, 0);