Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THCUNN/generic/SpatialDilatedConvolution.cu')
-rw-r--r--lib/THCUNN/generic/SpatialDilatedConvolution.cu12
1 files changed, 12 insertions, 0 deletions
diff --git a/lib/THCUNN/generic/SpatialDilatedConvolution.cu b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
index 02a640b..01c97c9 100644
--- a/lib/THCUNN/generic/SpatialDilatedConvolution.cu
+++ b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
@@ -89,6 +89,9 @@ void THNN_(SpatialDilatedConvolution_updateOutput)(
int nOutputPlane = weight->size[0];
input = THCTensor_(newContiguous)(state, input);
+ weight = THCTensor_(newContiguous)(state, weight);
+ bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
+
int batch = 1;
if (input->nDimension == 3) {
// Force batch
@@ -203,6 +206,8 @@ void THNN_(SpatialDilatedConvolution_updateOutput)(
}
THCTensor_(free)(state, input);
+ THCTensor_(free)(state, weight);
+ if (bias) THCTensor_(free)(state, bias);
}
void THNN_(SpatialDilatedConvolution_updateGradInput)(
@@ -229,6 +234,8 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)(
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ weight = THCTensor_(newContiguous)(state, weight);
+
int batch = 1;
if (input->nDimension == 3) {
// Force batch
@@ -308,6 +315,7 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)(
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
+ THCTensor_(free)(state, weight);
}
void THNN_(SpatialDilatedConvolution_accGradParameters)(
@@ -333,6 +341,10 @@ void THNN_(SpatialDilatedConvolution_accGradParameters)(
(state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW,
dilationH, dilationW);
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
+ if (gradBias)
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
+
// Params
int nInputPlane = gradWeight->size[1];
int nOutputPlane = gradWeight->size[0];