Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THCUNN/generic/SpatialFullConvolution.cu')
-rw-r--r--lib/THCUNN/generic/SpatialFullConvolution.cu11
1 files changed, 11 insertions, 0 deletions
diff --git a/lib/THCUNN/generic/SpatialFullConvolution.cu b/lib/THCUNN/generic/SpatialFullConvolution.cu
index 9e8d30f..76abb90 100644
--- a/lib/THCUNN/generic/SpatialFullConvolution.cu
+++ b/lib/THCUNN/generic/SpatialFullConvolution.cu
@@ -84,6 +84,9 @@ void THNN_(SpatialFullConvolution_updateOutput)(
(state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW, adjH, adjW);
input = THCTensor_(newContiguous)(state, input);
+ weight = THCTensor_(newContiguous)(state, weight);
+ bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
+
int batch = 1;
if (input->nDimension == 3) {
// Force batch
@@ -195,6 +198,9 @@ void THNN_(SpatialFullConvolution_updateOutput)(
}
THCTensor_(free)(state, input);
+ THCTensor_(free)(state, weight);
+ if (bias) THCTensor_(free)(state, bias);
+
}
void THNN_(SpatialFullConvolution_updateGradInput)(
@@ -219,6 +225,7 @@ void THNN_(SpatialFullConvolution_updateGradInput)(
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ weight = THCTensor_(newContiguous)(state, weight);
int batch = 1;
if (input->nDimension == 3) {
// Force batch
@@ -299,6 +306,7 @@ void THNN_(SpatialFullConvolution_updateGradInput)(
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
+ THCTensor_(free)(state, weight);
}
@@ -325,6 +333,9 @@ void THNN_(SpatialFullConvolution_accGradParameters)(
THNN_(SpatialFullConvolution_shapeCheck)
(state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW, adjH, adjW);
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
+ if (gradBias)
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
int batch = 1;