diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-07-21 21:26:29 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-07-21 21:26:29 +0300 |
commit | 14cedef2d03dcbdd95e49be935c7368ed3d475c5 (patch) | |
tree | 5be843e5a90d7209a98b62b06a2ae21a7063d583 | |
parent | 200ae7d55a3381a232256223c0694498f8f51df0 (diff) | |
parent | e3e71db6c064b69b2b3b0025537d4361c5ea93b1 (diff) |
Merge pull request #1264 from mikepound/upsampling
Added UpSampling module and associated tests.
-rw-r--r-- | UpSampling.lua | 216 | ||||
-rw-r--r-- | doc/convolution.md | 35 | ||||
-rwxr-xr-x | init.lua | 1 | ||||
-rw-r--r-- | lib/THNN/doc/api_reference.md | 123 | ||||
-rwxr-xr-x | test.lua | 26 |
5 files changed, 401 insertions, 0 deletions
diff --git a/UpSampling.lua b/UpSampling.lua new file mode 100644 index 0000000..9ad666f --- /dev/null +++ b/UpSampling.lua @@ -0,0 +1,216 @@ +require 'nn.THNN'
+local UpSampling, parent =
+ torch.class('nn.UpSampling', 'nn.Module')
+
+--[[
+Upsamples a given 2D (spatial) or 3D (volumetric) input using either nearest neighbor, or linear
+interpolation.
+
+The input data is assumed to be of the form `minibatch x channels x [depth] x height x width`.
+Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.
+
+The input parameter scale_factor specifies the amount of upsampling, and is assumed to be a positive
+integer. An optional mode parameter specifies either 'nearest' (the default) or 'linear'. Linear refers
+to either bilinear for spatial (4D) tensors, or trilinear for volumetric (5D) tensors.
+
+For nearest neighbour, output size will be:
+
+odepth = depth*scale_factor
+owidth = width*scale_factor
+oheight = height*scale_factor
+
+For linear interpolation:
+
+odepth = (depth-1)*(scale_factor-1) + depth
+owidth = (width-1)*(scale_factor-1) + width
+oheight = (height-1)*(scale_factor-1) + height
+
+Alternatively for bilinear or trilinear, [odepth], owidth and oheight can be directly provided as input
+--]]
+
-- Constructor.
-- params: either a positive integer scale factor, or (linear mode only) a
--         table {[odepth=D,] oheight=H, owidth=W} of explicit output sizes.
-- mode:   'nearest' (default) or any of 'linear'/'bilinear'/'trilinear',
--         which all collapse to 'linear' (the actual kernel -- bilinear or
--         trilinear -- is chosen at runtime from the input dimensionality).
function UpSampling:__init(params, mode)
   parent.__init(self)

   -- Any ambiguous or unrecognised mode falls back to nearest neighbour.
   if mode == 'linear' or mode == 'bilinear' or mode == 'trilinear' then
      self.mode = 'linear'
   else
      self.mode = 'nearest'
   end

   self.odepth, self.owidth, self.oheight, self.scale_factor = nil, nil, nil, nil
   if torch.type(params) == 'table' then
      -- Explicit output sizes are only meaningful for linear interpolation.
      if self.mode == 'nearest' then
         error ('Nearest neighbour upsampling requires a scale_factor')
      end
      self.odepth, self.owidth, self.oheight = params.odepth, params.owidth, params.oheight
      if self.owidth == nil or self.oheight == nil then
         error('Output height and width parameters are required')
      end
   else
      self.scale_factor = params
      -- Validate early with clear messages: the original comparison would
      -- raise an opaque 'compare nil with number' error for a nil/non-number
      -- argument, and its message claimed '> 1' while the check accepts 1.
      if type(self.scale_factor) ~= 'number' or self.scale_factor < 1 then
         error('scale_factor must be a number greater than or equal to 1')
      end
      if math.floor(self.scale_factor) ~= self.scale_factor then
         error('scale_factor must be integer')
      end
   end

   -- Sized 5 to cover the volumetric case; only the first input:dim()
   -- entries are ever filled by setSize().
   self.inputSize = torch.LongStorage(5):fill(0)
   self.outputSize = torch.LongStorage(5):fill(0)
end
+
-- Records the input tensor's size and derives the matching output size.
-- Convention: the last dimension is width (xdim), the one before it is
-- height (ydim) and, for 5D (volumetric) input only, the one before that
-- is depth (zdim). Leading dimensions (batch, channels) are copied as-is.
function UpSampling:setSize(input)
   local xdim = input:dim()
   local ydim = xdim - 1

   local zdim = nil
   if xdim > 4 then
      zdim = xdim - 2
   end

   -- Start from a straight copy; spatial dims are overwritten below.
   for i = 1, input:dim() do
      self.inputSize[i] = input:size(i)
      self.outputSize[i] = input:size(i)
   end
   if self.scale_factor ~= nil then
      -- Scale-factor mode: every spatial dimension is multiplied.
      if zdim ~= nil then
         self.outputSize[zdim] = self.outputSize[zdim] * self.scale_factor
      end
      self.outputSize[ydim] = self.outputSize[ydim] * self.scale_factor
      self.outputSize[xdim] = self.outputSize[xdim] * self.scale_factor
   else
      -- Explicit-size mode: use the dimensions supplied at construction.
      if zdim ~= nil then
         -- Runtime check that depth was supplied given received 5D input
         if self.odepth == nil then
            error ('No output depth dimension was supplied for volumetric upsampling')
         end
         self.outputSize[zdim] = self.odepth
      end
      self.outputSize[ydim] = self.oheight
      self.outputSize[xdim] = self.owidth
   end
end
+
-- Forward pass: upsamples `input` into self.output.
-- input: 4D (batch x channels x height x width) or
--        5D (batch x channels x depth x height x width) tensor.
-- Dispatches to the spatial or volumetric THNN kernel based on the number
-- of dimensions, and to the nearest/linear variant based on self.mode.
-- NOTE: argument order in the THNN calls is positional and must match the
-- C signatures exactly.
function UpSampling:updateOutput(input)
   local nDim = input:dim()
   if nDim < 4 or nDim > 5 then
      error('UpSampling only supports 4D or 5D tensors')
   end
   -- Width is the last dimension, height the one before; depth only for 5D.
   local xdim = nDim
   local ydim = xdim - 1
   local zdim
   if nDim == 5 then
      zdim = xdim - 2
   end
   -- Refresh self.outputSize from the current input before calling THNN.
   self:setSize(input)
   if nDim == 4 then
      if self.mode == 'nearest' then
         input.THNN.SpatialUpSamplingNearest_updateOutput(
            input:cdata(),
            self.output:cdata(),
            self.scale_factor
         )
      else
         -- Bilinear kernel takes explicit output height/width rather than
         -- a scale factor.
         input.THNN.SpatialUpSamplingBilinear_updateOutput(
            input:cdata(),
            self.output:cdata(),
            self.outputSize[ydim],
            self.outputSize[xdim]
         )
      end
   else
      if self.mode == 'nearest' then
         input.THNN.VolumetricUpSamplingNearest_updateOutput(
            input:cdata(),
            self.output:cdata(),
            self.scale_factor
         )
      else
         -- Trilinear kernel takes explicit output depth/height/width.
         input.THNN.VolumetricUpSamplingTrilinear_updateOutput(
            input:cdata(),
            self.output:cdata(),
            self.outputSize[zdim],
            self.outputSize[ydim],
            self.outputSize[xdim]
         )
      end
   end
   return self.output
end
+
-- Backward pass: computes self.gradInput from gradOutput.
-- The THNN signatures are asymmetric: the nearest-neighbour kernels take
-- the input tensor itself, while the bilinear/trilinear kernels instead
-- take the full input sizes (nbatch, nchannels, spatial dims) plus the
-- output spatial sizes, all positionally -- order must match the C API.
function UpSampling:updateGradInput(input, gradOutput)
   local nDim = input:dim()
   if nDim < 4 or nDim > 5 then
      error('UpSampling only supports 4D or 5D tensors')
   end
   if nDim ~= gradOutput:dim() then
      error('Input and gradOutput should be of same dimension')
   end
   -- Width is the last dimension, height the one before; depth only for 5D.
   local xdim = nDim
   local ydim = xdim - 1
   local zdim
   if nDim == 5 then
      zdim = xdim - 2
   end
   self.gradInput:resizeAs(input)
   if nDim == 4 then
      if self.mode == 'nearest' then
         input.THNN.SpatialUpSamplingNearest_updateGradInput(
            input:cdata(),
            gradOutput:cdata(),
            self.gradInput:cdata(),
            self.scale_factor
         )
      else
         input.THNN.SpatialUpSamplingBilinear_updateGradInput(
            gradOutput:cdata(),
            self.gradInput:cdata(),
            input:size(1),
            input:size(2),
            input:size(3),
            input:size(4),
            self.outputSize[ydim],
            self.outputSize[xdim]
         )
      end
   else
      if self.mode == 'nearest' then
         input.THNN.VolumetricUpSamplingNearest_updateGradInput(
            input:cdata(),
            gradOutput:cdata(),
            self.gradInput:cdata(),
            self.scale_factor
         )
      else
         input.THNN.VolumetricUpSamplingTrilinear_updateGradInput(
            gradOutput:cdata(),
            self.gradInput:cdata(),
            input:size(1),
            input:size(2),
            input:size(3),
            input:size(4),
            input:size(5),
            self.outputSize[zdim],
            self.outputSize[ydim],
            self.outputSize[xdim]
         )
      end
   end
   return self.gradInput
end
+
-- Human-readable description of the module: either '<type>(<N>x, <mode>)'
-- when built from a scale factor, or the explicit output size
-- '<type>(<D>x<H>x<W>, <mode>)' / '<type>(<H>x<W>, <mode>)' otherwise.
function UpSampling:__tostring__()
   local name = torch.type(self)
   if self.scale_factor ~= nil then
      return string.format('%s(%dx, %s)', name, self.scale_factor, self.mode)
   elseif self.odepth ~= nil then
      return string.format('%s(%dx%dx%d, %s)', name, self.odepth, self.oheight, self.owidth, self.mode)
   end
   return string.format('%s(%dx%d, %s)', name, self.oheight, self.owidth, self.mode)
end
diff --git a/doc/convolution.md b/doc/convolution.md index 82d890e..99b19b7 100644 --- a/doc/convolution.md +++ b/doc/convolution.md @@ -45,6 +45,7 @@ a kernel for computing the weighted average in a neighborhood ; * [VolumetricAveragePooling](#nn.VolumetricAveragePooling) : a 3D average-pooling operation over an input video. * [VolumetricMaxUnpooling](#nn.VolumetricMaxUnpooling) : a 3D max-unpooling operation. * [VolumetricReplicationPadding](#nn.VolumetricReplicationPadding) : Pads a volumetric feature map with the value at the edge of the input borders. ; + * [UpSampling](#nn.UpSampling): Upsampling for either spatial or volumetric inputs using nearest neighbor or linear interpolation. <a name="nn.TemporalModules"></a> @@ -1250,3 +1251,37 @@ module = nn.VolumetricReplicationPadding(padLeft, padRight, padTop, padBottom, ``` Each feature map of a given input is padded with the replication of the input boundary. + +<a name="nn.UpSampling"></a> +### UpSampling ### + +```lua +module = nn.UpSampling(scale, 'nearest') +module = nn.UpSampling(scale, 'linear') +module = nn.UpSampling({[odepth=D,] oheight=H, owidth=W}, 'linear') +``` + +Applies a 2D (spatial) or 3D (volumetric) up-sampling over an input image composed of several input planes. Available interpolation modes are nearest neighbor or linear (i.e. bilinear or trilinear depending on the input dimensions). The `input` tensor in `forward(input)` is expected to be of the form `minibatch x channels x [depth] x height x width`. I.e. for 4D input the final two dimensions will be upsampled, for 5D output the final three dimensions will be upsampled. The number of output planes will be the same. + +The parameters are the following: + * `scale`: The upscale ratio. Must be a positive integer. Required if using nearest neighbor. + * Or a table `{[odepth=D,] oheight=H, owidth=W}`: The required output depth, height and width, should be positive integers. 
+ * `mode`: The method of interpolation, either `'nearest'` or `'linear'`. Default is `'nearest'` + +If `scale` is specified, given an input of depth iD, height iH and width iW, output depth, height and width will be, for nearest neighbor: + +```lua +oD = iD * scale +oH = iH * scale +oW = iW * scale +``` + +For linear interpolation: + +```lua +oD = (iD - 1)(scale - 1) + iD +oH = (iH - 1)(scale - 1) + iH +oW = (iW - 1)(scale - 1) + iW +``` + +There are no learnable parameters. @@ -147,6 +147,7 @@ require('nn.SpatialReplicationPadding') require('nn.SpatialUpSamplingNearest') require('nn.SpatialUpSamplingBilinear') require('nn.SpatialBatchNormalization') +require('nn.UpSampling') require('nn.VolumetricConvolution') require('nn.VolumetricFullConvolution') diff --git a/lib/THNN/doc/api_reference.md b/lib/THNN/doc/api_reference.md index 830cc3d..70c5c79 100644 --- a/lib/THNN/doc/api_reference.md +++ b/lib/THNN/doc/api_reference.md @@ -59,7 +59,10 @@ These are all modules implemented in THNN: * [SpatialMaxPooling](#spatialmaxpooling) * [SpatialMaxUnpooling](#spatialmaxunpooling) * [SpatialSubSampling](#spatialsubsampling) +* [SpatialReflectionPadding](#spatialreflectionpadding) +* [SpatialReplicationPadding](#spatialreplicationpadding) * [SpatialUpSamplingNearest](#spatialupsamplingnearest) +* [SpatialUpSamplingBilinear](#spatialupsamplingbilinear) * [Sqrt](#sqrt) * [Square](#square) * [Tanh](#tanh) @@ -70,6 +73,9 @@ These are all modules implemented in THNN: * [VolumetricFullConvolution](#volumetricfullconvolution) * [VolumetricMaxPooling](#volumetricmaxpooling) * [VolumetricMaxUnpooling](#volumetricmaxunpooling) +* [VolumetricReplicationPadding](#volumetricreplicationpadding) +* [VolumetricUpSamplingNearest](#volumetricupsamplingnearest) +* [VolumetricUpSamplingTrilinear](#volumetricupsamplingtrilinear) ## Abs ```C @@ -1254,6 +1260,42 @@ void THNN_SpatialSubSampling_accGradParameters( int dW, int dH, real scale); ``` +## SpatialReflectionPadding +```C +TH_API void 
THNN_(SpatialReflectionPadding_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *output, + int pad_l, int pad_r, + int pad_t, int pad_b); +``` +```C +TH_API void THNN_(SpatialReflectionPadding_updateGradInput)( + THNNState *state, + THTensor *input, + THTensor *gradOutput, + THTensor *gradInput, + int pad_l, int pad_r, + int pad_t, int pad_b); +``` +## SpatialReplicationPadding +```C +TH_API void THNN_(SpatialReplicationPadding_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *output, + int pad_l, int pad_r, + int pad_t, int pad_b); +``` +```C +TH_API void THNN_(SpatialReplicationPadding_updateGradInput)( + THNNState *state, + THTensor *input, + THTensor *gradOutput, + THTensor *gradInput, + int pad_l, int pad_r, + int pad_t, int pad_b); +``` ## SpatialUpSamplingNearest ```C void THNN_SpatialUpSamplingNearest_updateOutput( @@ -1270,6 +1312,27 @@ void THNN_SpatialUpSamplingNearest_updateGradInput( THTensor *gradInput, int scale_factor); ``` +## SpatialUpSamplingBilinear +```C +TH_API void THNN_(SpatialUpSamplingBilinear_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *output, + int outputHeight, + int outputWidth); +``` +```C +TH_API void THNN_(SpatialUpSamplingBilinear_updateGradInput)( + THNNState *state, + THTensor *gradOutput, + THTensor *gradInput, + int nbatch, + int nchannels, + int inputHeight, + int inputWidth, + int outputHeight, + int outputWidth); +``` ## Sqrt ```C void THNN_Sqrt_updateOutput( @@ -1507,3 +1570,63 @@ void THNN_VolumetricMaxUnpooling_updateGradInput( int dT, int dW, int dH, int pT, int pW, int pH); ``` +## VolumetricReplicationPadding +```C +TH_API void THNN_(VolumetricReplicationPadding_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *output, + int pleft, int pright, + int ptop, int pbottom, + int pfront, int pback); +``` +```C +TH_API void THNN_(VolumetricReplicationPadding_updateGradInput)( + THNNState *state, + THTensor *input, + THTensor *gradOutput, + THTensor 
*gradInput, + int pleft, int pright, + int ptop, int pbottom, + int pfront, int pback); +``` +## VolumetricUpSamplingNearest +```C +TH_API void THNN_(VolumetricUpSamplingNearest_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *output, + int scale_factor); +``` +```C +TH_API void THNN_(VolumetricUpSamplingNearest_updateGradInput)( + THNNState *state, + THTensor *input, + THTensor *gradOutput, + THTensor *gradInput, + int scale_factor); +``` +## VolumetricUpSamplingTrilinear +```C +TH_API void THNN_(VolumetricUpSamplingTrilinear_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *output, + int outputDepth, + int outputHeight, + int outputWidth); +``` +```C +TH_API void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)( + THNNState *state, + THTensor *gradOutput, + THTensor *gradInput, + int nbatch, + int nchannels, + int inputDepth, + int inputHeight, + int inputWidth, + int outputDepth, + int outputHeight, + int outputWidth); +``` @@ -6784,6 +6784,32 @@ function nntest.SpatialUpSamplingBilinear() end end +function nntest.UpSampling() + -- Test nearest and linear modes + for _,mode in pairs({'nearest','linear'}) do + for scale=2,4 do + for dim = 4,5 do + local m = nn.UpSampling(scale, mode) + + -- Create a randomly sized dimD vector + local shape = {} + for i = 1, dim do + table.insert(shape, torch.random(2, 4)) + end + + -- Check that the gradient is correct by using finite elements + local input = torch.Tensor(table.unpack(shape)):zero() + local err = jac.testJacobian(m, input) + mytester:assertlt(err, precision, ' error on state ') + + local ferr, berr = jac.testIO(m, input) + mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ') + end + end + end +end + function nntest.Concat() local input = torch.randn(4, 2) local num_modules = math.random(2, 5) |