diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-08-18 17:25:17 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-18 17:25:17 +0300 |
commit | 8d07b574d01f013edd8ae0a7df067788301096a2 (patch) | |
tree | eb6919adf6164a75163bb9949a10ee7a223ddfae | |
parent | 5cfd34497d0cc7e180ee370767a3ca684176a10c (diff) | |
parent | 85e46603dfada0d45af6ba5e5c3ad95c536a04e6 (diff) |
Merge pull request #934 from apaszke/spatial_conv
Accept both 2D and 4D weights in SpatialConvolutionMM
-rw-r--r-- | SpatialConvolution.lua | 23 | ||||
-rw-r--r-- | lib/THNN/generic/SpatialConvolutionMM.c | 47 |
2 files changed, 45 insertions, 25 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 8324f95..01a08cd 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -89,25 +89,9 @@ local function makeContiguous(self, input, gradOutput) return input, gradOutput end --- function to re-view the weight layout in a way that would make the MM ops happy -local function viewWeight(self) - self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW) - if self.gradWeight and self.gradWeight:dim() > 0 then - self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane * self.kH * self.kW) - end -end - -local function unviewWeight(self) - self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW) - if self.gradWeight and self.gradWeight:dim() > 0 then - self.gradWeight = self.gradWeight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW) - end -end - function SpatialConvolution:updateOutput(input) assert(input.THNN, torch.type(input)..'.THNN backend not imported') backCompatibility(self) - viewWeight(self) input = makeContiguous(self, input) input.THNN.SpatialConvolutionMM_updateOutput( input:cdata(), @@ -120,7 +104,6 @@ function SpatialConvolution:updateOutput(input) self.dW, self.dH, self.padW, self.padH ) - unviewWeight(self) return self.output end @@ -128,20 +111,18 @@ function SpatialConvolution:updateGradInput(input, gradOutput) assert(input.THNN, torch.type(input)..'.THNN backend not imported') if self.gradInput then backCompatibility(self) - viewWeight(self) input, gradOutput = makeContiguous(self, input, gradOutput) input.THNN.SpatialConvolutionMM_updateGradInput( input:cdata(), gradOutput:cdata(), self.gradInput:cdata(), - self.weight:cdata(), + self.weight:cdata(), self.finput:cdata(), self.fgradInput:cdata(), self.kW, self.kH, self.dW, self.dH, self.padW, self.padH ) - unviewWeight(self) return self.gradInput end end @@ -151,7 +132,6 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale) scale = scale or 1 backCompatibility(self) input, gradOutput = makeContiguous(self, input, gradOutput) - viewWeight(self) input.THNN.SpatialConvolutionMM_accGradParameters( input:cdata(), gradOutput:cdata(), @@ -164,7 +144,6 @@ function SpatialConvolution:accGradParameters(input, gradOutput, scale) self.padW, self.padH, scale ) - unviewWeight(self) end function SpatialConvolution:type(type,tensorCache) diff --git a/lib/THNN/generic/SpatialConvolutionMM.c b/lib/THNN/generic/SpatialConvolutionMM.c index e7460c8..1ca68a9 100644 --- a/lib/THNN/generic/SpatialConvolutionMM.c +++ b/lib/THNN/generic/SpatialConvolutionMM.c @@ -67,9 +67,12 @@ void THNN_(SpatialConvolutionMM_updateOutput)( long outputWidth; long outputHeight; + int freeWeight = 0; + THArgCheck( input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch mode) tensor expected"); THArgCheck(kW > 0 && kH > 0, 8, "kernel size should be greater than zero"); THArgCheck(dW > 0 && dH > 0, 10, "stride should be greater than zero"); + THArgCheck(weight->nDimension == 2 || weight->nDimension == 4, 4, "weight tensor should be 2D or 4D"); if (input->nDimension == 4) { dimf++; @@ -88,8 +91,19 @@ void THNN_(SpatialConvolutionMM_updateOutput)( THError("Given input size: (%dx%dx%d). Calculated output size: (%dx%dx%d). Output size is too small", nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth); - if (nInputPlane*kW*kH != weight->size[1]) - THError("Wrong number of input channels! Input has %d channels, expected %d",nInputPlane,weight->size[1]/(kW*kH)); + + int expectedWeightSize = weight->nDimension == 2 ? nInputPlane*kW*kH : nInputPlane; + int weightInputPlanes = weight->nDimension == 2 ? weight->size[1]/(kW*kH) : weight->size[1]; + if (expectedWeightSize != weight->size[1]) + THError("Wrong number of input channels! Input has %d channels, expected %d", + nInputPlane, weightInputPlanes); + + if (weight->nDimension == 4) { + long s1 = weight->size[0]; + long s2 = weight->size[1] * weight->size[2] * weight->size[3]; + weight = THTensor_(newWithStorage2d)(weight->storage, 0, s1, -1, s2, -1); + freeWeight = 1; + } if(input->nDimension == 3) { @@ -126,6 +140,9 @@ void THNN_(SpatialConvolutionMM_updateOutput)( THTensor_(free)(finput_t); } } + + if (freeWeight) + THTensor_(free)(weight); } static void THNN_(SpatialConvolutionMM_updateGradInput_frame)( @@ -167,17 +184,27 @@ void THNN_(SpatialConvolutionMM_updateGradInput)( int padH) { long nOutputPlane = weight->size[0]; + int freeWeight = 0; THArgCheck( nOutputPlane == gradOutput->size[input->nDimension == 4 ? 1 : 0], 3, "Number of output features is not equal to nOutputPlane" ); THArgCheck(kW > 0 && kH > 0, 9, "kernel size should be greater than zero"); THArgCheck(dW > 0 && dH > 0, 11, "stride should be greater than zero"); + THArgCheck(weight->nDimension == 2 || weight->nDimension == 4, 4, "weight tensor should be 2D or 4D"); THTensor_(resizeAs)(gradInput, input); THTensor_(resizeAs)(fgradInput, finput); // depending on the BLAS library, fgradInput (result tensor) might // be left uninitialized on zero alpha, which might lead to weird behavior // hence, to be safe, zero it - THTensor_(zero)(fgradInput); + THTensor_(zero)(fgradInput); + + if (weight->nDimension == 4) { + long s1 = weight->size[0]; + long s2 = weight->size[1] * weight->size[2] * weight->size[3]; + weight = THTensor_(newWithStorage2d)(weight->storage, 0, s1, -1, s2, -1); + freeWeight = 1; + } + THTensor_(transpose)(weight, weight, 0, 1); if(input->nDimension == 3) @@ -205,6 +232,9 @@ void THNN_(SpatialConvolutionMM_updateGradInput)( } THTensor_(transpose)(weight, weight, 0, 1); + + if (freeWeight) + THTensor_(free)(weight); } static void THNN_(SpatialConvolutionMM_accGradParameters_frame)( @@ -254,10 +284,19 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( int padH, real scale) { + int freeWeight = 0; long nOutputPlane = gradWeight->size[0]; THArgCheck( nOutputPlane == gradOutput->size[input->nDimension == 4 ? 1 : 0], 3, "Number of output features is not equal to nOutputPlane" ); THArgCheck(kW > 0 && kH > 0, 8, "kernel size should be greater than zero"); THArgCheck(dW > 0 && dH > 0, 10, "stride should be greater than zero"); + THArgCheck(gradWeight->nDimension == 2 || gradWeight->nDimension == 4, 4, "gradWeight tensor should be 2D or 4D"); + + if (gradWeight->nDimension == 4) { + long s1 = gradWeight->size[0]; + long s2 = gradWeight->size[1] * gradWeight->size[2] * gradWeight->size[3]; + gradWeight = THTensor_(newWithStorage2d)(gradWeight->storage, 0, s1, -1, s2, -1); + freeWeight = 1; + } if(input->nDimension == 3) { @@ -279,6 +318,8 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( THTensor_(free)(finput_t); } } + if (freeWeight) + THTensor_(free)(gradWeight); } #endif |