Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Pound <michael.pound@nottingham.ac.uk>2017-07-21 14:25:50 +0300
committerMichael Pound <michael.pound@nottingham.ac.uk>2017-07-21 14:25:50 +0300
commite3e71db6c064b69b2b3b0025537d4361c5ea93b1 (patch)
tree5be843e5a90d7209a98b62b06a2ae21a7063d583 /UpSampling.lua
parent200ae7d55a3381a232256223c0694498f8f51df0 (diff)
Added UpSampling module and associated tests.
Diffstat (limited to 'UpSampling.lua')
-rw-r--r--UpSampling.lua216
1 files changed, 216 insertions, 0 deletions
diff --git a/UpSampling.lua b/UpSampling.lua
new file mode 100644
index 0000000..9ad666f
--- /dev/null
+++ b/UpSampling.lua
@@ -0,0 +1,216 @@
+require 'nn.THNN'
+local UpSampling, parent =
+ torch.class('nn.UpSampling', 'nn.Module')
+
+--[[
+Upsamples a given 2D (spatial) or 3D (volumetric) input using either nearest neighbor, or linear
+interpolation.
+
+The input data is assumed to be of the form `minibatch x channels x [depth] x height x width`.
+Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.
+
+The input parameter scale_factor specifies the amount of upsampling, and is assumed to be a positive
+integer. An optional mode parameter specifies either 'nearest' (the default) or 'linear'. Linear refers
+to either bilinear for spatial (4D) tensors, or trilinear for volumetric (5D) tensors.
+
+For nearest neighbour, output size will be:
+
+odepth = depth*scale_factor
+owidth = width*scale_factor
+oheight = height*scale_factor
+
+For linear interpolation:
+
+owidth = (width-1)*(scale_factor-1) + width
+owidth = (width-1)*(scale_factor-1) + width
+oheight = (height-1)*(scale_factor-1) + height
+
+Alternatively for bilinear or trilinear, [odepth], owidth and oheight can be directly provided as input
+--]]
+
+function UpSampling:__init(params, mode)
+ parent.__init(self)
+
+ -- Any ambigious mode will default to nearest
+ if mode ~= nil and (mode == 'linear' or mode == 'bilinear' or mode == 'trilinear') then
+ self.mode = 'linear'
+ else
+ self.mode = 'nearest'
+ end
+
+ self.odepth, self.owidth, self.oheight, self.scale_factor = nil, nil, nil, nil
+ if torch.type(params) == 'table' then
+ if self.mode == 'nearest' then
+ error ('Nearest neighbour upsampling requires a scale_factor')
+ end
+ self.odepth, self.owidth, self.oheight = params.odepth, params.owidth, params.oheight
+ if self.owidth == nil or self.oheight == nil then
+ error('Output height and width parameters are required')
+ end
+ else
+ self.scale_factor = params
+ if self.scale_factor < 1 then
+ error('scale_factor must be greater than 1')
+ end
+ if math.floor(self.scale_factor) ~= self.scale_factor then
+ error('scale_factor must be integer')
+ end
+ end
+
+ self.inputSize = torch.LongStorage(5):fill(0)
+ self.outputSize = torch.LongStorage(5):fill(0)
+end
+
+function UpSampling:setSize(input)
+ local xdim = input:dim()
+ local ydim = xdim - 1
+
+ local zdim = nil
+ if xdim > 4 then
+ zdim = xdim - 2
+ end
+
+ for i = 1, input:dim() do
+ self.inputSize[i] = input:size(i)
+ self.outputSize[i] = input:size(i)
+ end
+ if self.scale_factor ~= nil then
+ if zdim ~= nil then
+ self.outputSize[zdim] = self.outputSize[zdim] * self.scale_factor
+ end
+ self.outputSize[ydim] = self.outputSize[ydim] * self.scale_factor
+ self.outputSize[xdim] = self.outputSize[xdim] * self.scale_factor
+ else
+ if zdim ~= nil then
+ -- Runtime chech that depth was supplied given received 5D input
+ if self.odepth == nil then
+ error ('No output depth dimension was supplied for volumetric upsampling')
+ end
+ self.outputSize[zdim] = self.odepth
+ end
+ self.outputSize[ydim] = self.oheight
+ self.outputSize[xdim] = self.owidth
+ end
+end
+
+function UpSampling:updateOutput(input)
+ local nDim = input:dim()
+ if nDim < 4 or nDim > 5 then
+ error('UpSampling only supports 4D or 5D tensors')
+ end
+ local xdim = nDim
+ local ydim = xdim - 1
+ local zdim
+ if nDim == 5 then
+ zdim = xdim - 2
+ end
+ self:setSize(input)
+ if nDim == 4 then
+ if self.mode == 'nearest' then
+ input.THNN.SpatialUpSamplingNearest_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.scale_factor
+ )
+ else
+ input.THNN.SpatialUpSamplingBilinear_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.outputSize[ydim],
+ self.outputSize[xdim]
+ )
+ end
+ else
+ if self.mode == 'nearest' then
+ input.THNN.VolumetricUpSamplingNearest_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.scale_factor
+ )
+ else
+ input.THNN.VolumetricUpSamplingTrilinear_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.outputSize[zdim],
+ self.outputSize[ydim],
+ self.outputSize[xdim]
+ )
+ end
+ end
+ return self.output
+end
+
+function UpSampling:updateGradInput(input, gradOutput)
+ local nDim = input:dim()
+ if nDim < 4 or nDim > 5 then
+ error('UpSampling only supports 4D or 5D tensors')
+ end
+ if nDim ~= gradOutput:dim() then
+ error('Input and gradOutput should be of same dimension')
+ end
+ local xdim = nDim
+ local ydim = xdim - 1
+ local zdim
+ if nDim == 5 then
+ zdim = xdim - 2
+ end
+ self.gradInput:resizeAs(input)
+ if nDim == 4 then
+ if self.mode == 'nearest' then
+ input.THNN.SpatialUpSamplingNearest_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.scale_factor
+ )
+ else
+ input.THNN.SpatialUpSamplingBilinear_updateGradInput(
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ input:size(1),
+ input:size(2),
+ input:size(3),
+ input:size(4),
+ self.outputSize[ydim],
+ self.outputSize[xdim]
+ )
+ end
+ else
+ if self.mode == 'nearest' then
+ input.THNN.VolumetricUpSamplingNearest_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.scale_factor
+ )
+ else
+ input.THNN.VolumetricUpSamplingTrilinear_updateGradInput(
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ input:size(1),
+ input:size(2),
+ input:size(3),
+ input:size(4),
+ input:size(5),
+ self.outputSize[zdim],
+ self.outputSize[ydim],
+ self.outputSize[xdim]
+ )
+ end
+ end
+ return self.gradInput
+end
+
+function UpSampling:__tostring__()
+ local s
+ if self.scale_factor ~= nil then
+ s = string.format('%s(%dx, %s)', torch.type(self), self.scale_factor, self.mode)
+ else
+ if self.odepth ~= nil then
+ s = string.format('%s(%dx%dx%d, %s)', torch.type(self), self.odepth, self.oheight, self.owidth, self.mode)
+ else
+ s = string.format('%s(%dx%d, %s)', torch.type(self), self.oheight, self.owidth, self.mode)
+ end
+ end
+ return s
+end