Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-04-22 11:05:43 +0300
committerGitHub <noreply@github.com>2017-04-22 11:05:43 +0300
commit455e488488bdaa20a50f82586975e69a0332c97a (patch)
tree5ddace698e0db1297534588fdeb4076c6267975e
parentea15d0d649edc52fc3e25a26a14c9bcc0070339a (diff)
parent0112e3f31d817c0a15fbbfb449ffc44f902131b1 (diff)
Merge pull request #465 from torch/cunnchecks
add contiguous checks
-rw-r--r--lib/THCUNN/generic/PReLU.cu5
-rw-r--r--lib/THCUNN/generic/SparseLinear.cu5
-rw-r--r--lib/THCUNN/generic/SpatialConvolutionLocal.cu30
-rw-r--r--lib/THCUNN/generic/SpatialDilatedConvolution.cu12
-rw-r--r--lib/THCUNN/generic/SpatialFullConvolution.cu11
-rw-r--r--lib/THCUNN/generic/TemporalConvolution.cu3
-rw-r--r--lib/THCUNN/generic/TemporalRowConvolution.cu5
-rw-r--r--lib/THCUNN/generic/VolumetricDilatedConvolution.cu8
-rw-r--r--lib/THCUNN/generic/VolumetricFullConvolution.cu13
9 files changed, 76 insertions, 16 deletions
diff --git a/lib/THCUNN/generic/PReLU.cu b/lib/THCUNN/generic/PReLU.cu
index 16ea7ad..949e3d9 100644
--- a/lib/THCUNN/generic/PReLU.cu
+++ b/lib/THCUNN/generic/PReLU.cu
@@ -11,6 +11,7 @@ void THNN_(PReLU_updateOutput)(
{
THCTensor_(resizeAs)(state, output, input);
+ weight = THCTensor_(newContiguous)(state, weight);
real *w = THCTensor_(data)(state, weight);
if (nOutputPlane == 0)
@@ -40,6 +41,8 @@ void THNN_(PReLU_updateOutput)(
THCudaCheck(cudaGetLastError());
THCTensor_(free)(state, input);
}
+
+ THCTensor_(free)(state, weight);
}
void THNN_(PReLU_updateGradInput)(
@@ -53,6 +56,7 @@ void THNN_(PReLU_updateGradInput)(
THCUNN_check_nElement(state, input, gradOutput);
THCTensor_(resizeAs)(state, gradInput, input);
+ weight = THCTensor_(newContiguous)(state, weight);
real *w = THCTensor_(data)(state, weight);
if (nOutputPlane == 0)
{
@@ -84,6 +88,7 @@ void THNN_(PReLU_updateGradInput)(
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
}
+ THCTensor_(free)(state, weight);
}
void THNN_(PReLU_accGradParameters)(
diff --git a/lib/THCUNN/generic/SparseLinear.cu b/lib/THCUNN/generic/SparseLinear.cu
index 23a5c94..70c9f5b 100644
--- a/lib/THCUNN/generic/SparseLinear.cu
+++ b/lib/THCUNN/generic/SparseLinear.cu
@@ -44,6 +44,8 @@ void THNN_(SparseLinear_updateOutput)(
THArgCheck(THCTensor_(nDimension)(state, output) == 2, 3, "output must be batchsize x outputsize");
THArgCheck(checkSize1D(bias, outDim), 5, "bias size wrong");
+ weight = THCTensor_(newContiguous)(state, weight);
+
long batchnum = THCTensor_(size)(state, output, 0);
long nnz = THCTensor_(size)(state, input, 0);
@@ -114,6 +116,7 @@ void THNN_(SparseLinear_updateOutput)(
THCTensor_(free)(state, buffer);
THCTensor_(free)(state, sel);
THCTensor_(free)(state, values);
+ THCTensor_(free)(state, weight);
THCudaIntTensor_free(state, rowbuf);
THCudaIntTensor_free(state, colInds);
THCudaIntTensor_free(state, csrPtrs);
@@ -137,6 +140,7 @@ void THNN_(SparseLinear_accGradParameters)(
THArgCheck(checkSize2D(gradWeight, outDim, inDim), 4, "gradWeight size wrong");
THArgCheck(checkSize1D(gradBias, outDim), 5, "gradBias size wrong");
+ weight = THCTensor_(newContiguous)(state, weight);
long nnz = THCTensor_(size)(state, input, 0);
long batchnum = THCTensor_(size)(state, gradOutput, 0);
@@ -212,6 +216,7 @@ void THNN_(SparseLinear_accGradParameters)(
THCTensor_(cadd)(state, gradBias, gradBias, weightDecay, bias);
}
+ THCTensor_(free)(state, weight);
THCTensor_(free)(state, buf);
THCTensor_(free)(state, sel);
THCTensor_(free)(state, cols);
diff --git a/lib/THCUNN/generic/SpatialConvolutionLocal.cu b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
index 9cbddd1..1799449 100644
--- a/lib/THCUNN/generic/SpatialConvolutionLocal.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
@@ -48,24 +48,25 @@ static inline void THNN_(SpatialConvolutionLocal_shapeCheck)(
}
}
-static int THNN_(view_weight_local)(
+static THCTensor* THNN_(view_weight_local)(
THCState *state,
- THCTensor **_weight)
+ THCTensor *_weight)
{
- THCTensor *weight = *_weight;
+ THTensor *weight = THCTensor_(newContiguous)(state, _weight);
THArgCheck(weight->nDimension == 3 || weight->nDimension == 6, 4,
"weight tensor should be 3D or 6D - got %dD", weight->nDimension);
if (weight->nDimension == 6) {
long s1 = weight->size[0] * weight->size[1];
long s2 = weight->size[2];
long s3 = weight->size[3] * weight->size[4] * weight->size[5];
- *_weight = THCTensor_(newWithStorage3d)(state,
+ THCTensor *old_weight = weight;
+ weight = THCTensor_(newWithStorage3d)(state,
weight->storage,
weight->storageOffset,
s1, -1, s2, -1, s3, -1);
- return 1;
+ THCTensor_(free)(state, old_weight);
}
- return 0;
+ return weight;
}
void THNN_(SpatialConvolutionLocal_updateOutput)(
@@ -85,7 +86,7 @@ void THNN_(SpatialConvolutionLocal_updateOutput)(
THCUNN_assertSameGPU(state, 5, input, output, weight,
bias, finput);
- int freeWeight = THNN_(view_weight_local)(state, &weight);
+ weight = THNN_(view_weight_local)(state, weight);
THNN_(SpatialConvolutionLocal_shapeCheck)
(state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW,
@@ -175,8 +176,7 @@ void THNN_(SpatialConvolutionLocal_updateOutput)(
}
THCTensor_(free)(state, input);
- if (freeWeight)
- THCTensor_(free)(state, weight);
+ THCTensor_(free)(state, weight);
}
void THNN_(SpatialConvolutionLocal_updateGradInput)(
@@ -196,7 +196,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)(
THCUNN_assertSameGPU(state, 5, input, gradOutput, weight,
fgradInput, gradInput);
- int freeWeight = THNN_(view_weight_local)(state, &weight);
+ weight = THNN_(view_weight_local)(state, weight);
THNN_(SpatialConvolutionLocal_shapeCheck)
(state, input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW,
@@ -292,8 +292,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)(
THCTensor_(free)(state, tweight);
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
- if (freeWeight)
- THCTensor_(free)(state, weight);
+ THCTensor_(free)(state, weight);
}
void THNN_(SpatialConvolutionLocal_accGradParameters)(
@@ -315,7 +314,9 @@ void THNN_(SpatialConvolutionLocal_accGradParameters)(
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight,
gradBias, finput);
- int freeWeight = THNN_(view_weight_local)(state, &gradWeight);
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
+ gradWeight = THNN_(view_weight_local)(state, gradWeight);
THNN_(SpatialConvolutionLocal_shapeCheck)
(state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW,
@@ -400,8 +401,7 @@ void THNN_(SpatialConvolutionLocal_accGradParameters)(
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
- if (freeWeight)
- THCTensor_(free)(state, gradWeight);
+ THCTensor_(free)(state, gradWeight);
}
#endif
diff --git a/lib/THCUNN/generic/SpatialDilatedConvolution.cu b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
index 02a640b..01c97c9 100644
--- a/lib/THCUNN/generic/SpatialDilatedConvolution.cu
+++ b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
@@ -89,6 +89,9 @@ void THNN_(SpatialDilatedConvolution_updateOutput)(
int nOutputPlane = weight->size[0];
input = THCTensor_(newContiguous)(state, input);
+ weight = THCTensor_(newContiguous)(state, weight);
+ bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
+
int batch = 1;
if (input->nDimension == 3) {
// Force batch
@@ -203,6 +206,8 @@ void THNN_(SpatialDilatedConvolution_updateOutput)(
}
THCTensor_(free)(state, input);
+ THCTensor_(free)(state, weight);
+ if (bias) THCTensor_(free)(state, bias);
}
void THNN_(SpatialDilatedConvolution_updateGradInput)(
@@ -229,6 +234,8 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)(
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ weight = THCTensor_(newContiguous)(state, weight);
+
int batch = 1;
if (input->nDimension == 3) {
// Force batch
@@ -308,6 +315,7 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)(
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
+ THCTensor_(free)(state, weight);
}
void THNN_(SpatialDilatedConvolution_accGradParameters)(
@@ -333,6 +341,10 @@ void THNN_(SpatialDilatedConvolution_accGradParameters)(
(state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW,
dilationH, dilationW);
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
+ if (gradBias)
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
+
// Params
int nInputPlane = gradWeight->size[1];
int nOutputPlane = gradWeight->size[0];
diff --git a/lib/THCUNN/generic/SpatialFullConvolution.cu b/lib/THCUNN/generic/SpatialFullConvolution.cu
index 9e8d30f..76abb90 100644
--- a/lib/THCUNN/generic/SpatialFullConvolution.cu
+++ b/lib/THCUNN/generic/SpatialFullConvolution.cu
@@ -84,6 +84,9 @@ void THNN_(SpatialFullConvolution_updateOutput)(
(state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW, adjH, adjW);
input = THCTensor_(newContiguous)(state, input);
+ weight = THCTensor_(newContiguous)(state, weight);
+ bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
+
int batch = 1;
if (input->nDimension == 3) {
// Force batch
@@ -195,6 +198,9 @@ void THNN_(SpatialFullConvolution_updateOutput)(
}
THCTensor_(free)(state, input);
+ THCTensor_(free)(state, weight);
+ if (bias) THCTensor_(free)(state, bias);
+
}
void THNN_(SpatialFullConvolution_updateGradInput)(
@@ -219,6 +225,7 @@ void THNN_(SpatialFullConvolution_updateGradInput)(
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ weight = THCTensor_(newContiguous)(state, weight);
int batch = 1;
if (input->nDimension == 3) {
// Force batch
@@ -299,6 +306,7 @@ void THNN_(SpatialFullConvolution_updateGradInput)(
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
+ THCTensor_(free)(state, weight);
}
@@ -325,6 +333,9 @@ void THNN_(SpatialFullConvolution_accGradParameters)(
THNN_(SpatialFullConvolution_shapeCheck)
(state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW, adjH, adjW);
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
+ if (gradBias)
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
int batch = 1;
diff --git a/lib/THCUNN/generic/TemporalConvolution.cu b/lib/THCUNN/generic/TemporalConvolution.cu
index abe4b54..de27c30 100644
--- a/lib/THCUNN/generic/TemporalConvolution.cu
+++ b/lib/THCUNN/generic/TemporalConvolution.cu
@@ -53,6 +53,8 @@ void THNN_(TemporalConvolution_updateOutput)(
THCUNN_assertSameGPU(state, 4, input, output, weight, bias);
THNN_(TemporalConvolution_shapeCheck)
(state, input, kW, dW, &inputFrameSize);
+ THArgCheck(THCTensor_(isContiguous)(state, weight), 4, "weight must be contiguous");
+ THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5, "bias must be contiguous");
if (input->nDimension == 3)
{
@@ -180,6 +182,7 @@ void THNN_(TemporalConvolution_updateGradInput)(
int dimS = 0; // sequence dimension
THCUNN_assertSameGPU(state, 4, input, gradOutput, weight, gradInput);
+ THArgCheck(THCTensor_(isContiguous)(state, weight), 4, "weight must be contiguous");
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
diff --git a/lib/THCUNN/generic/TemporalRowConvolution.cu b/lib/THCUNN/generic/TemporalRowConvolution.cu
index 9959322..0063570 100644
--- a/lib/THCUNN/generic/TemporalRowConvolution.cu
+++ b/lib/THCUNN/generic/TemporalRowConvolution.cu
@@ -62,6 +62,9 @@ void THNN_(TemporalRowConvolution_updateOutput)(
THCUNN_assertSameGPU(state, 2, weight, bias);
}
+ THArgCheck(THCTensor_(isContiguous)(state, weight), 4, "weight must be contiguous");
+ THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5, "bias must be contiguous");
+
// reshape weight if necessary
int ndim = input->nDimension;
@@ -190,6 +193,8 @@ void THNN_(TemporalRowConvolution_updateGradInput)(
THCUNN_assertSameGPU(state, 5, input, gradOutput, weight, gradColumns,
gradInput);
+ THArgCheck(THCTensor_(isContiguous)(state, weight), 4, "weight must be contiguous");
+
int ndim = input->nDimension;
THCTensor *tinput, *tgradOutput;
diff --git a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
index ffeea7f..45bb0f6 100644
--- a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
@@ -96,6 +96,9 @@ void THNN_(VolumetricDilatedConvolution_updateOutput)(
int nOutputPlane = weight->size[0];
input = THCTensor_(newContiguous)(state, input);
+ weight = THCTensor_(newContiguous)(state, weight);
+ bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
+
int batch = 1;
if (input->nDimension == 4) {
// Force batch
@@ -213,6 +216,8 @@ void THNN_(VolumetricDilatedConvolution_updateOutput)(
}
THCTensor_(free)(state, input);
+ THCTensor_(free)(state, weight);
+ if (bias) THCTensor_(free)(state, bias);
}
void THNN_(VolumetricDilatedConvolution_updateGradInput)(
@@ -234,6 +239,8 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)(
kT, kH, kW, dT, dH, dW, padT, padH, padW,
dilationT, dilationH, dilationW);
+ weight = THCTensor_(newContiguous)(state, weight);
+
// Params
int nInputPlane = weight->size[1];
int nOutputPlane = weight->size[0];
@@ -322,6 +329,7 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)(
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
+ THCTensor_(free)(state, weight);
}
void THNN_(VolumetricDilatedConvolution_accGradParameters)(
diff --git a/lib/THCUNN/generic/VolumetricFullConvolution.cu b/lib/THCUNN/generic/VolumetricFullConvolution.cu
index eb8e9e2..9dd266c 100644
--- a/lib/THCUNN/generic/VolumetricFullConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricFullConvolution.cu
@@ -101,6 +101,8 @@ void THNN_(VolumetricFullConvolution_updateOutput)(
adjT, adjW, adjH);
input = THCTensor_(newContiguous)(state, input);
+ weight = THCTensor_(newContiguous)(state, weight);
+ bias = bias ? THCTensor_(newContiguous)(state, bias) : bias;
int batch = 1;
if (input->nDimension == 4) {
@@ -216,6 +218,9 @@ void THNN_(VolumetricFullConvolution_updateOutput)(
}
THCTensor_(free)(state, input);
+ THCTensor_(free)(state, weight);
+ if (bias) THCTensor_(free)(state, bias);
+
}
void THNN_(VolumetricFullConvolution_updateGradInput)(
@@ -247,7 +252,8 @@ void THNN_(VolumetricFullConvolution_updateGradInput)(
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);
-
+ weight = THCTensor_(newContiguous)(state, weight);
+
int batch = 1;
if (input->nDimension == 4) {
// Force batch
@@ -331,6 +337,7 @@ void THNN_(VolumetricFullConvolution_updateGradInput)(
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
+ THCTensor_(free)(state, weight);
}
@@ -364,6 +371,10 @@ void THNN_(VolumetricFullConvolution_accGradParameters)(
gradBias, dT, dW, dH, padT, padW, padH,
adjT, adjW, adjH);
+ THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous");
+ if (gradBias)
+ THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous");
+
input = THCTensor_(newContiguous)(state, input);
gradOutput = THCTensor_(newContiguous)(state, gradOutput);