diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-12-01 00:35:07 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-12-01 00:58:22 +0300 |
commit | 92eac4244fe7ecce915665bcea2c2fed43289429 (patch) | |
tree | 130c23d28adbdb42c7ce85c605e022c483d3818d | |
parent | c26131cc79bbe5a90b0e3f82b786af229863148d (diff) |
Add newContiguous calls that have been removed from lua.
-rw-r--r-- | lib/THCUNN/generic/BatchNormalization.cu | 1 | ||||
-rw-r--r-- | lib/THCUNN/generic/LookupTable.cu | 4 | ||||
-rw-r--r-- | lib/THCUNN/generic/SpatialConvolutionLocal.cu | 2 | ||||
-rw-r--r-- | lib/THCUNN/generic/SpatialConvolutionMM.cu | 13 | ||||
-rw-r--r-- | lib/THCUNN/generic/SpatialDilatedConvolution.cu | 13 | ||||
-rw-r--r-- | lib/THCUNN/generic/SpatialFullConvolution.cu | 13 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricDilatedConvolution.cu | 13 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricFullConvolution.cu | 15 |
8 files changed, 73 insertions, 1 deletions
diff --git a/lib/THCUNN/generic/BatchNormalization.cu b/lib/THCUNN/generic/BatchNormalization.cu index acb0b18..cbe99f3 100644 --- a/lib/THCUNN/generic/BatchNormalization.cu +++ b/lib/THCUNN/generic/BatchNormalization.cu @@ -37,6 +37,7 @@ void THNN_(BatchNormalization_updateOutput)( THCTensor *runningVar_, THCTensor *saveMean_, THCTensor *saveStd_, bool train, double momentum, double eps) { + THCTensor_(resizeAs)(state, output_, input_); DeviceTensor3 input = devicetensor<3>(state, input_); DeviceTensor3 output = devicetensor<3>(state, output_); DeviceTensor1 weight = devicetensor<1>(state, weight_); diff --git a/lib/THCUNN/generic/LookupTable.cu b/lib/THCUNN/generic/LookupTable.cu index d1e99ab..85423e1 100644 --- a/lib/THCUNN/generic/LookupTable.cu +++ b/lib/THCUNN/generic/LookupTable.cu @@ -15,8 +15,8 @@ void THNN_(LookupTable_accGradParameters)( real scale) { THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, sorted, indices); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); if (!(THCIndexTensor_(isContiguous)(state, input) && - THCTensor_(isContiguous)(state, gradOutput) && THCTensor_(isContiguous)(state, gradWeight))) { THError("Tensors must be contiguous"); @@ -108,6 +108,8 @@ void THNN_(LookupTable_accGradParameters)( stride, paddingValue ); + + THCTensor_(free)(state, gradOutput); THCudaCheck(cudaGetLastError()); } diff --git a/lib/THCUNN/generic/SpatialConvolutionLocal.cu b/lib/THCUNN/generic/SpatialConvolutionLocal.cu index 6fe52a5..afbc24d 100644 --- a/lib/THCUNN/generic/SpatialConvolutionLocal.cu +++ b/lib/THCUNN/generic/SpatialConvolutionLocal.cu @@ -203,6 +203,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)( inputHeight, inputWidth, outputHeight, outputWidth); input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); long nInputPlane = THCTensor_(size)(state,weight,2)/(kW*kH); long nOutputPlane = THCTensor_(size)(state,weight,1); @@ -290,6 +291,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)( THCTensor_(transpose)(state, weight, weight, 1, 2); THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); if (freeWeight) THCTensor_(free)(state, weight); } diff --git a/lib/THCUNN/generic/SpatialConvolutionMM.cu b/lib/THCUNN/generic/SpatialConvolutionMM.cu index 71d1155..01848f4 100644 --- a/lib/THCUNN/generic/SpatialConvolutionMM.cu +++ b/lib/THCUNN/generic/SpatialConvolutionMM.cu @@ -85,6 +85,7 @@ void THNN_(SpatialConvolutionMM_updateOutput)( THNN_(SpatialConvolutionMM_shapeCheck) (state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW); + input = THCTensor_(newContiguous)(state, input); int batch = 1; if (input->nDimension == 3) { // Force batch @@ -198,6 +199,8 @@ void THNN_(SpatialConvolutionMM_updateOutput)( THCTensor_(resize3d)(state, output, nOutputPlane, outputHeight, outputWidth); THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); } void THNN_(SpatialConvolutionMM_updateGradInput)( @@ -230,6 +233,8 @@ void THNN_(SpatialConvolutionMM_updateGradInput)( THNN_(SpatialConvolutionMM_shapeCheck) (state, input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW); + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); int batch = 1; if (input->nDimension == 3) { // Force batch @@ -307,6 +312,9 @@ void THNN_(SpatialConvolutionMM_updateGradInput)( THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); THCTensor_(resize3d)(state, gradInput, nInputPlane, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); } void THNN_(SpatialConvolutionMM_accGradParameters)( @@ -342,6 +350,8 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( THNN_(SpatialConvolutionMM_shapeCheck) (state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW); + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); int batch = 1; if (input->nDimension == 3) { // Force batch @@ -460,6 +470,9 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( THCTensor_(resize3d)(state, gradOutput, nOutputPlane, outputHeight, outputWidth); THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); } #endif diff --git a/lib/THCUNN/generic/SpatialDilatedConvolution.cu b/lib/THCUNN/generic/SpatialDilatedConvolution.cu index ae8a2cd..c790ab4 100644 --- a/lib/THCUNN/generic/SpatialDilatedConvolution.cu +++ b/lib/THCUNN/generic/SpatialDilatedConvolution.cu @@ -84,6 +84,7 @@ void THNN_(SpatialDilatedConvolution_updateOutput)( int nInputPlane = weight->size[1]; int nOutputPlane = weight->size[0]; + input = THCTensor_(newContiguous)(state, input); int batch = 1; if (input->nDimension == 3) { // Force batch @@ -196,6 +197,8 @@ void THNN_(SpatialDilatedConvolution_updateOutput)( THCTensor_(resize3d)(state, output, nOutputPlane, outputHeight, outputWidth); THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); } void THNN_(SpatialDilatedConvolution_updateGradInput)( @@ -220,6 +223,8 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)( int nInputPlane = weight->size[1]; int nOutputPlane = weight->size[0]; + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); int batch = 1; if (input->nDimension == 3) { // Force batch @@ -296,6 +301,9 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)( THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); THCTensor_(resize3d)(state, gradInput, nInputPlane, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); } void THNN_(SpatialDilatedConvolution_accGradParameters)( @@ -324,6 +332,8 @@ void THNN_(SpatialDilatedConvolution_accGradParameters)( int nInputPlane = gradWeight->size[1]; int nOutputPlane = gradWeight->size[0]; + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); int batch = 1; if (input->nDimension == 3) { // Force batch @@ -441,6 +451,9 @@ void THNN_(SpatialDilatedConvolution_accGradParameters)( THCTensor_(resize3d)(state, gradOutput, nOutputPlane, outputHeight, outputWidth); THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); } #endif diff --git a/lib/THCUNN/generic/SpatialFullConvolution.cu b/lib/THCUNN/generic/SpatialFullConvolution.cu index 1da1d0d..395e3c6 100644 --- a/lib/THCUNN/generic/SpatialFullConvolution.cu +++ b/lib/THCUNN/generic/SpatialFullConvolution.cu @@ -76,6 +76,7 @@ void THNN_(SpatialFullConvolution_updateOutput)( THNN_(SpatialFullConvolution_shapeCheck) (state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW, adjH, adjW); + input = THCTensor_(newContiguous)(state, input); int batch = 1; if (input->nDimension == 3) { // Force batch @@ -186,6 +187,8 @@ void THNN_(SpatialFullConvolution_updateOutput)( THCTensor_(resize3d)(state, output, nOutputPlane, outputHeight, outputWidth); THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); } void THNN_(SpatialFullConvolution_updateGradInput)( @@ -208,6 +211,8 @@ void THNN_(SpatialFullConvolution_updateGradInput)( THNN_(SpatialFullConvolution_shapeCheck) (state, input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW, adjH, adjW); + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); int batch = 1; if (input->nDimension == 3) { // Force batch @@ -285,6 +290,9 @@ void THNN_(SpatialFullConvolution_updateGradInput)( THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); THCTensor_(resize3d)(state, gradInput, nInputPlane, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); } @@ -310,6 +318,8 @@ void THNN_(SpatialFullConvolution_accGradParameters)( THNN_(SpatialFullConvolution_shapeCheck) (state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW, adjH, adjW); + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); int batch = 1; if (input->nDimension == 3) { // Force batch @@ -426,6 +436,9 @@ void THNN_(SpatialFullConvolution_accGradParameters)( THCTensor_(resize3d)(state, gradOutput, nOutputPlane, outputHeight, outputWidth); THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); } #endif diff --git a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu index a0214c8..268d690 100644 --- a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu +++ b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu @@ -33,6 +33,7 @@ void THNN_(VolumetricDilatedConvolution_updateOutput)( int nInputPlane = weight->size[1]; int nOutputPlane = weight->size[0]; + input = THCTensor_(newContiguous)(state, input); int batch = 1; if (input->nDimension == 4) { THArgCheck(input->size[0] == nInputPlane, 2, "input channels and nInputPlane dont match"); @@ -155,6 +156,8 @@ void THNN_(VolumetricDilatedConvolution_updateOutput)( THCTensor_(resize4d)(state, output, nOutputPlane, outputDepth, outputHeight, outputWidth); THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); } void THNN_(VolumetricDilatedConvolution_updateGradInput)( @@ -186,6 +189,8 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)( int nInputPlane = weight->size[1]; int nOutputPlane = weight->size[0]; + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); int batch = 1; if (input->nDimension == 4) { // Force batch @@ -265,6 +270,9 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)( THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth); THCTensor_(resize4d)(state, gradInput, nInputPlane, inputDepth, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); } void THNN_(VolumetricDilatedConvolution_accGradParameters)( @@ -300,6 +308,8 @@ void THNN_(VolumetricDilatedConvolution_accGradParameters)( int nInputPlane = gradWeight->size[1]; int nOutputPlane = gradWeight->size[0]; + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); int batch = 1; if (input->nDimension == 4) { // Force batch @@ -419,6 +429,9 @@ void THNN_(VolumetricDilatedConvolution_accGradParameters)( THCTensor_(resize4d)(state, gradOutput, nOutputPlane, outputDepth, outputHeight, outputWidth); THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); } #endif diff --git a/lib/THCUNN/generic/VolumetricFullConvolution.cu b/lib/THCUNN/generic/VolumetricFullConvolution.cu index c794ade..f48566c 100644 --- a/lib/THCUNN/generic/VolumetricFullConvolution.cu +++ b/lib/THCUNN/generic/VolumetricFullConvolution.cu @@ -32,6 +32,7 @@ void THNN_(VolumetricFullConvolution_updateOutput)( "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " "expected for weight, but got: %s"); + input = THCTensor_(newContiguous)(state, input); int batch = 1; if (input->nDimension == 4) { @@ -147,6 +148,8 @@ void THNN_(VolumetricFullConvolution_updateOutput)( THCTensor_(resize4d)(state, output, nOutputPlane, outputDepth, outputHeight, outputWidth); THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); } void THNN_(VolumetricFullConvolution_updateGradInput)( @@ -177,6 +180,9 @@ void THNN_(VolumetricFullConvolution_updateGradInput)( "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " "expected for weight, but got: %s"); + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); + int batch = 1; if (input->nDimension == 4) { // Force batch @@ -257,6 +263,9 @@ void THNN_(VolumetricFullConvolution_updateGradInput)( THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth); THCTensor_(resize4d)(state, gradInput, nInputPlane, inputDepth, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); } @@ -291,6 +300,9 @@ void THNN_(VolumetricFullConvolution_accGradParameters)( "expected for gradWeight, but got: %s"); + input = THCTensor_(newContiguous)(state, input); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); + int batch = 1; if (input->nDimension == 4) { // Force batch @@ -408,6 +420,9 @@ void THNN_(VolumetricFullConvolution_accGradParameters)( THCTensor_(resize4d)(state, gradOutput, nOutputPlane, outputDepth, outputHeight, outputWidth); THCTensor_(resize4d)(state, input, nInputPlane, inputDepth, inputHeight, inputWidth); } + + THCTensor_(free)(state, input); + THCTensor_(free)(state, gradOutput); } #endif |