Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@fb.com>2016-10-08 04:14:33 +0300
committerSoumith Chintala <soumith@gmail.com>2016-10-17 07:37:50 +0300
commitfcdf644d0d2986932ce38149b05778225b2c9b5d (patch)
tree57593461b7348082d3345b0b17f35f33279cb80f /SpatialUpSamplingBilinear.lua
parenteb7656e621347e3d9150ed593279bae63dd004f4 (diff)
more improvments on error messages and shape checks
Diffstat (limited to 'SpatialUpSamplingBilinear.lua')
-rw-r--r--SpatialUpSamplingBilinear.lua62
1 files changed, 36 insertions, 26 deletions
diff --git a/SpatialUpSamplingBilinear.lua b/SpatialUpSamplingBilinear.lua
index 9e1b3d6..54ce5b8 100644
--- a/SpatialUpSamplingBilinear.lua
+++ b/SpatialUpSamplingBilinear.lua
@@ -51,37 +51,40 @@ local function makeContiguous(self, input, gradOutput)
return input, gradOutput
end
-function SpatialUpSamplingBilinear:updateOutput(input)
- assert(input:dim() == 4 or input:dim()==3,
- 'SpatialUpSamplingBilinear only supports 3D or 4D tensors' )
- local inputwas3D = false
- if input:dim() == 3 then
- input=input:view(-1, input:size(1), input:size(2), input:size(3))
- inputwas3D = true
- end
- input = makeContiguous(self, input)
- assert(input:dim() == 4)
- -- Copy the input size
+function SpatialUpSamplingBilinear:setSize(input)
local xdim = input:dim()
- local ydim = input:dim() - 1
+ local ydim = xdim - 1
for i = 1, input:dim() do
self.inputSize[i] = input:size(i)
self.outputSize[i] = input:size(i)
end
if self.scale_factor ~= nil then
- self.outputSize[ydim] = (self.outputSize[ydim]-1) * (self.scale_factor-1)
- + self.outputSize[ydim]
- self.outputSize[xdim] = (self.outputSize[xdim]-1) * (self.scale_factor -1)
- + self.outputSize[xdim]
+ self.outputSize[ydim] = self.outputSize[ydim] * self.scale_factor
+ self.outputSize[xdim] = self.outputSize[xdim] * self.scale_factor
else
self.outputSize[ydim] = self.oheight
self.outputSize[xdim] = self.owidth
end
- -- Resize the output if needed
+end
+
+function SpatialUpSamplingBilinear:updateOutput(input)
+ assert(input:dim() == 4 or input:dim()==3,
+ 'SpatialUpSamplingBilinear only supports 3D or 4D tensors' )
+ local inputwas3D = false
+ if input:dim() == 3 then
+ input=input:view(-1, input:size(1), input:size(2), input:size(3))
+ inputwas3D = true
+ end
+ input = makeContiguous(self, input)
+ local xdim = input:dim()
+ local ydim = xdim - 1
+ self:setSize(input)
self.output:resize(self.outputSize)
input.THNN.SpatialUpSamplingBilinear_updateOutput(
input:cdata(),
- self.output:cdata()
+ self.output:cdata(),
+ self.outputSize[ydim],
+ self.outputSize[xdim]
)
if inputwas3D then
input = input:squeeze(1)
@@ -94,19 +97,26 @@ function SpatialUpSamplingBilinear:updateGradInput(input, gradOutput)
assert(input:dim() == 4 or input:dim()==3,
'SpatialUpSamplingBilinear only support 3D or 4D tensors' )
assert(input:dim() == gradOutput:dim(),
- 'Input and gradOutput should be of same dimension' )
+ 'Input and gradOutput should be of same dimension' )
local inputwas3D = false
if input:dim() == 3 then
- input=input:view(-1, input:size(1), input:size(2), input:size(3))
- gradOutput=gradOutput:view(-1, gradOutput:size(1), gradOutput:size(2),
- gradOutput:size(3))
+ input = input:view(-1, input:size(1), input:size(2), input:size(3))
+ gradOutput = gradOutput:view(-1, gradOutput:size(1), gradOutput:size(2),
+ gradOutput:size(3))
inputwas3D = true
end
- assert(input:dim() == 4 and gradOutput:dim() == 4)
- self.gradInput:resizeAs(input)
+ local xdim = input:dim()
+ local ydim = xdim - 1
+ self.gradInput:resizeAs(input)
input.THNN.SpatialUpSamplingBilinear_updateGradInput(
gradOutput:cdata(),
- self.gradInput:cdata()
+ self.gradInput:cdata(),
+ input:size(1),
+ input:size(2),
+ input:size(3),
+ input:size(4),
+ self.outputSize[ydim],
+ self.outputSize[xdim]
)
if inputwas3D then
input = input:squeeze(1)
@@ -126,4 +136,4 @@ function SpatialUpSamplingBilinear:__tostring__()
torch.type(self), self.oheight, self.owidth)
end
return s
-end \ No newline at end of file
+end