Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THCUNN/generic/VolumetricFullConvolution.cu')
-rw-r--r--lib/THCUNN/generic/VolumetricFullConvolution.cu13
1 files changed, 12 insertions, 1 deletions
diff --git a/lib/THCUNN/generic/VolumetricFullConvolution.cu b/lib/THCUNN/generic/VolumetricFullConvolution.cu
index eb8e9e2..9dd266c 100644
--- a/lib/THCUNN/generic/VolumetricFullConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricFullConvolution.cu
@@ -101,6 +101,8 @@ void THNN_(VolumetricFullConvolution_updateOutput)(
adjT, adjW, adjH);
input = THCTensor_(newContiguous)(state, input);
+ weight = THCTensor_(newContiguous)(state, weight);
+ bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
int batch = 1;
if (input->nDimension == 4) {
@@ -216,6 +218,9 @@ void THNN_(VolumetricFullConvolution_updateOutput)(
}
THCTensor_(free)(state, input);
+ THCTensor_(free)(state, weight);
+ if (bias) THCTensor_(free)(state, bias);
+
}
void THNN_(VolumetricFullConvolution_updateGradInput)(
@@ -247,7 +252,8 @@ void THNN_(VolumetricFullConvolution_updateGradInput)(
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
-
+ weight = THCTensor_(newContiguous)(state, weight);
+
int batch = 1;
if (input->nDimension == 4) {
// Force batch
@@ -331,6 +337,7 @@ void THNN_(VolumetricFullConvolution_updateGradInput)(
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
+ THCTensor_(free)(state, weight);
}
@@ -364,6 +371,10 @@ void THNN_(VolumetricFullConvolution_accGradParameters)(
gradBias, dT, dW, dH, padT, padW, padH,
adjT, adjW, adjH);
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
+ if (gradBias)
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
+
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);