diff options
author | Soumith Chintala <soumith@fb.com> | 2016-10-08 04:14:33 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-10-17 07:37:50 +0300 |
commit | fcdf644d0d2986932ce38149b05778225b2c9b5d (patch) | |
tree | 57593461b7348082d3345b0b17f35f33279cb80f /SpatialUpSamplingBilinear.lua | |
parent | eb7656e621347e3d9150ed593279bae63dd004f4 (diff) |
more improvments on error messages and shape checks
Diffstat (limited to 'SpatialUpSamplingBilinear.lua')
-rw-r--r-- | SpatialUpSamplingBilinear.lua | 62 |
1 files changed, 36 insertions, 26 deletions
diff --git a/SpatialUpSamplingBilinear.lua b/SpatialUpSamplingBilinear.lua index 9e1b3d6..54ce5b8 100644 --- a/SpatialUpSamplingBilinear.lua +++ b/SpatialUpSamplingBilinear.lua @@ -51,37 +51,40 @@ local function makeContiguous(self, input, gradOutput) return input, gradOutput end -function SpatialUpSamplingBilinear:updateOutput(input) - assert(input:dim() == 4 or input:dim()==3, - 'SpatialUpSamplingBilinear only supports 3D or 4D tensors' ) - local inputwas3D = false - if input:dim() == 3 then - input=input:view(-1, input:size(1), input:size(2), input:size(3)) - inputwas3D = true - end - input = makeContiguous(self, input) - assert(input:dim() == 4) - -- Copy the input size +function SpatialUpSamplingBilinear:setSize(input) local xdim = input:dim() - local ydim = input:dim() - 1 + local ydim = xdim - 1 for i = 1, input:dim() do self.inputSize[i] = input:size(i) self.outputSize[i] = input:size(i) end if self.scale_factor ~= nil then - self.outputSize[ydim] = (self.outputSize[ydim]-1) * (self.scale_factor-1) - + self.outputSize[ydim] - self.outputSize[xdim] = (self.outputSize[xdim]-1) * (self.scale_factor -1) - + self.outputSize[xdim] + self.outputSize[ydim] = self.outputSize[ydim] * self.scale_factor + self.outputSize[xdim] = self.outputSize[xdim] * self.scale_factor else self.outputSize[ydim] = self.oheight self.outputSize[xdim] = self.owidth end - -- Resize the output if needed +end + +function SpatialUpSamplingBilinear:updateOutput(input) + assert(input:dim() == 4 or input:dim()==3, + 'SpatialUpSamplingBilinear only supports 3D or 4D tensors' ) + local inputwas3D = false + if input:dim() == 3 then + input=input:view(-1, input:size(1), input:size(2), input:size(3)) + inputwas3D = true + end + input = makeContiguous(self, input) + local xdim = input:dim() + local ydim = xdim - 1 + self:setSize(input) self.output:resize(self.outputSize) input.THNN.SpatialUpSamplingBilinear_updateOutput( input:cdata(), - self.output:cdata() + self.output:cdata(), + self.outputSize[ydim], + self.outputSize[xdim] ) if inputwas3D then input = input:squeeze(1) @@ -94,19 +97,26 @@ function SpatialUpSamplingBilinear:updateGradInput(input, gradOutput) assert(input:dim() == 4 or input:dim()==3, 'SpatialUpSamplingBilinear only support 3D or 4D tensors' ) assert(input:dim() == gradOutput:dim(), - 'Input and gradOutput should be of same dimension' ) + 'Input and gradOutput should be of same dimension' ) local inputwas3D = false if input:dim() == 3 then - input=input:view(-1, input:size(1), input:size(2), input:size(3)) - gradOutput=gradOutput:view(-1, gradOutput:size(1), gradOutput:size(2), - gradOutput:size(3)) + input = input:view(-1, input:size(1), input:size(2), input:size(3)) + gradOutput = gradOutput:view(-1, gradOutput:size(1), gradOutput:size(2), + gradOutput:size(3)) inputwas3D = true end - assert(input:dim() == 4 and gradOutput:dim() == 4) - self.gradInput:resizeAs(input) + local xdim = input:dim() + local ydim = xdim - 1 + self.gradInput:resizeAs(input) input.THNN.SpatialUpSamplingBilinear_updateGradInput( gradOutput:cdata(), - self.gradInput:cdata() + self.gradInput:cdata(), + input:size(1), + input:size(2), + input:size(3), + input:size(4), + self.outputSize[ydim], + self.outputSize[xdim] ) if inputwas3D then input = input:squeeze(1) @@ -126,4 +136,4 @@ function SpatialUpSamplingBilinear:__tostring__() torch.type(self), self.oheight, self.owidth) end return s -end
\ No newline at end of file +end |