Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THCUNN/generic/SpatialConvolutionLocal.cu')
-rw-r--r--lib/THCUNN/generic/SpatialConvolutionLocal.cu30
1 files changed, 15 insertions, 15 deletions
diff --git a/lib/THCUNN/generic/SpatialConvolutionLocal.cu b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
index 9cbddd1..1799449 100644
--- a/lib/THCUNN/generic/SpatialConvolutionLocal.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
@@ -48,24 +48,25 @@ static inline void THNN_(SpatialConvolutionLocal_shapeCheck)(
}
}
-static int THNN_(view_weight_local)(
+static THCTensor* THNN_(view_weight_local)(
THCState *state,
- THCTensor **_weight)
+ THCTensor *_weight)
{
- THCTensor *weight = *_weight;
+ THTensor *weight = THCTensor_(newContiguous)(state, _weight);
THArgCheck(weight->nDimension == 3 || weight->nDimension == 6, 4,
"weight tensor should be 3D or 6D - got %dD", weight->nDimension);
if (weight->nDimension == 6) {
long s1 = weight->size[0] * weight->size[1];
long s2 = weight->size[2];
long s3 = weight->size[3] * weight->size[4] * weight->size[5];
- *_weight = THCTensor_(newWithStorage3d)(state,
+ THCTensor *old_weight = weight;
+ weight = THCTensor_(newWithStorage3d)(state,
weight->storage,
weight->storageOffset,
s1, -1, s2, -1, s3, -1);
- return 1;
+ THCTensor_(free)(state, old_weight);
}
- return 0;
+ return weight;
}
void THNN_(SpatialConvolutionLocal_updateOutput)(
@@ -85,7 +86,7 @@ void THNN_(SpatialConvolutionLocal_updateOutput)(
THCUNN_assertSameGPU(state, 5, input, output, weight,
bias, finput);
- int freeWeight = THNN_(view_weight_local)(state, &weight);
+ weight = THNN_(view_weight_local)(state, weight);
THNN_(SpatialConvolutionLocal_shapeCheck)
(state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW,
@@ -175,8 +176,7 @@ void THNN_(SpatialConvolutionLocal_updateOutput)(
}
THCTensor_(free)(state, input);
- if (freeWeight)
- THCTensor_(free)(state, weight);
+ THCTensor_(free)(state, weight);
}
void THNN_(SpatialConvolutionLocal_updateGradInput)(
@@ -196,7 +196,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)(
THCUNN_assertSameGPU(state, 5, input, gradOutput, weight,
fgradInput, gradInput);
- int freeWeight = THNN_(view_weight_local)(state, &weight);
+ weight = THNN_(view_weight_local)(state, weight);
THNN_(SpatialConvolutionLocal_shapeCheck)
(state, input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW,
@@ -292,8 +292,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)(
THCTensor_(free)(state, tweight);
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
- if (freeWeight)
- THCTensor_(free)(state, weight);
+ THCTensor_(free)(state, weight);
}
void THNN_(SpatialConvolutionLocal_accGradParameters)(
@@ -315,7 +314,9 @@ void THNN_(SpatialConvolutionLocal_accGradParameters)(
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight,
gradBias, finput);
- int freeWeight = THNN_(view_weight_local)(state, &gradWeight);
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
+ gradWeight = THNN_(view_weight_local)(state, gradWeight);
THNN_(SpatialConvolutionLocal_shapeCheck)
(state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW,
@@ -400,8 +401,7 @@ void THNN_(SpatialConvolutionLocal_accGradParameters)(
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
- if (freeWeight)
- THCTensor_(free)(state, gradWeight);
+ THCTensor_(free)(state, gradWeight);
}
#endif