diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-01-24 01:22:24 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-01-24 01:22:24 +0300 |
commit | bd59f99e7ce48aa8e53cdf3e106d7feb01685f89 (patch) | |
tree | 87449f65566e7c1b6d68e6b6671bb3d18083c600 | |
parent | 223d4a19438029f1fd25d335d628f349656ddcbd (diff) | |
parent | 1d70c17029cb3605c40a56029b3340ddf6f62c0f (diff) |
Merge pull request #423 from apaszke/contig_checks
Add contiguity checks to THCUNN
-rw-r--r-- | lib/THCUNN/generic/SpatialConvolutionMM.cu | 9 | ||||
-rw-r--r-- | lib/THCUNN/generic/SpatialDilatedConvolution.cu | 4 | ||||
-rw-r--r-- | lib/THCUNN/generic/SpatialFullConvolution.cu | 4 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricConvolution.cu | 6 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricDilatedConvolution.cu | 4 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricFullConvolution.cu | 4 |
6 files changed, 31 insertions, 0 deletions
diff --git a/lib/THCUNN/generic/SpatialConvolutionMM.cu b/lib/THCUNN/generic/SpatialConvolutionMM.cu index 01848f4..e7aeacb 100644 --- a/lib/THCUNN/generic/SpatialConvolutionMM.cu +++ b/lib/THCUNN/generic/SpatialConvolutionMM.cu @@ -11,6 +11,8 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)( "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW); THArgCheck(dW > 0 && dH > 0, 11, "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5, + "bias tensor has to be contiguous"); THCUNN_argCheck(state, weight->nDimension == 2 || weight->nDimension == 4, 5, weight, "2D or 4D weight tensor expected, but got: %s"); @@ -69,6 +71,9 @@ void THNN_(SpatialConvolutionMM_updateOutput)( if (bias) { THCUNN_assertSameGPU(state, 2, weight, bias); } + THArgCheck(THCTensor_(isContiguous)(state, weight), 4, + "weight tensor has to be contiguous"); + int freeWeight = 0; // Params: @@ -217,6 +222,8 @@ void THNN_(SpatialConvolutionMM_updateGradInput)( THCUNN_assertSameGPU(state, 5, input, gradOutput, weight, gradColumns, gradInput); + THArgCheck(THCTensor_(isContiguous)(state, weight), 4, + "weight tensor has to be contiguous"); // Params int nInputPlane = weight->nDimension == 2 ? weight->size[1]/(kW*kH) : weight->size[1]; @@ -334,6 +341,8 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( if (gradBias) { THCUNN_assertSameGPU(state, 2, gradWeight, gradBias); } + THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, + "weight tensor has to be contiguous"); // Params int nInputPlane = gradWeight->nDimension == 2 ? gradWeight->size[1]/(kW*kH) : gradWeight->size[1]; diff --git a/lib/THCUNN/generic/SpatialDilatedConvolution.cu b/lib/THCUNN/generic/SpatialDilatedConvolution.cu index c790ab4..7b656d3 100644 --- a/lib/THCUNN/generic/SpatialDilatedConvolution.cu +++ b/lib/THCUNN/generic/SpatialDilatedConvolution.cu @@ -16,6 +16,10 @@ static inline void THNN_(SpatialDilatedConvolution_shapeCheck)( "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW); THArgCheck(dW > 0 && dH > 0, 11, "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + THArgCheck(THCTensor_(isContiguous)(state, weight), 4, + "weight tensor has to be contiguous"); + THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5, + "bias tensor has to be contiguous"); THArgCheck(dilationW > 0 && dilationH > 0, 14, "dilation should be greater than 0, but got dilationH: %d dilationW: %d", dilationH, dilationW); diff --git a/lib/THCUNN/generic/SpatialFullConvolution.cu b/lib/THCUNN/generic/SpatialFullConvolution.cu index 12995d2..7a5d7ea 100644 --- a/lib/THCUNN/generic/SpatialFullConvolution.cu +++ b/lib/THCUNN/generic/SpatialFullConvolution.cu @@ -15,6 +15,10 @@ static inline void THNN_(SpatialFullConvolution_shapeCheck)( THArgCheck(adjW < dW && adjH < dH, 15, "output adjustment must be smaller than stride, but got adjH: %d adjW: %d dH: %d dW: %d", adjH, adjW, dH, dW); + THArgCheck(THCTensor_(isContiguous)(state, weight), 4, + "weight tensor has to be contiguous"); + THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5, + "bias tensor has to be contiguous"); THCUNN_argCheck(state, weight->nDimension == 2 || weight->nDimension == 4, 5, weight, "2D or 4D weight tensor expected, but got: %s"); diff --git a/lib/THCUNN/generic/VolumetricConvolution.cu b/lib/THCUNN/generic/VolumetricConvolution.cu index a371ac8..d6da545 100644 --- a/lib/THCUNN/generic/VolumetricConvolution.cu +++ b/lib/THCUNN/generic/VolumetricConvolution.cu @@ -17,6 +17,12 @@ static inline void THNN_(VolumetricConvolution_shapeCheck) int padH) { THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input, "4D or 5D (batch mode) tensor expected for input, but got: %s"); + THArgCheck(!weight || THCTensor_(isContiguous)(state, weight), 4, + "weight tensor has to be contiguous"); + THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5, + "bias tensor has to be contiguous"); + THArgCheck(!gradWeight || THCTensor_(isContiguous)(state, gradWeight), 5, + "gradWeight tensor has to be contiguous"); THArgCheck(dT > 0 && dW > 0 && dH > 0, 10, "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW); diff --git a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu index 422cdc7..b0145a5 100644 --- a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu +++ b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu @@ -21,6 +21,10 @@ static inline void THNN_(VolumetricDilatedConvolution_shapeCheck)( "kernel size should be greater than zero, but got kT: %d kH: %d kW: %d", kT, kH, kW); THArgCheck(dT > 0 && dW > 0 && dH > 0, 11, "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW); + THArgCheck(THCTensor_(isContiguous)(state, weight), 4, + "weight tensor has to be contiguous"); + THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5, + "bias tensor has to be contiguous"); THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 15, "dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d", dilationT, dilationH, dilationW); diff --git a/lib/THCUNN/generic/VolumetricFullConvolution.cu b/lib/THCUNN/generic/VolumetricFullConvolution.cu index 47f4943..334c7da 100644 --- a/lib/THCUNN/generic/VolumetricFullConvolution.cu +++ b/lib/THCUNN/generic/VolumetricFullConvolution.cu @@ -17,6 +17,10 @@ static inline void THNN_(VolumetricFullConvolution_shapeCheck)( THCUNN_argCheck(state, weight->nDimension == 5, 4, weight, "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " "expected for weight, but got: %s"); + THArgCheck(THCTensor_(isContiguous)(state, weight), 4, + "weight tensor has to be contiguous"); + THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5, + "bias tensor has to be contiguous"); THArgCheck(dT > 0 && dW > 0 && dH > 0, 8, "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW); THArgCheck(adjT < dT && adjW < dW && adjH < dH, 14, |