Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SpatialDownSampling.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 2aa4216b33cb1e60ed503a848b065846a3481bba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
-- Register nn.SpatialDownSampling as a subclass of nn.Module. `parent` is
-- the base class table, used by the constructor to chain to nn.Module.
local SpatialDownSampling, parent =
   torch.class('nn.SpatialDownSampling', 'nn.Module')

-- Help text handed to xlua.unpack_class in the constructor. The dimension
-- order described here must agree with updateOutput, which divides tensor
-- dimension 2 by rH and dimension 3 by rW (i.e. planes x height x width).
local help_desc = [[
Applies a 2D down-sampling over an input image composed of
several input planes. The input tensor in forward(input) is
expected to be a 3D tensor (nInputPlane x height x width).
The number of output planes will be the same as nInputPlane.

The downsampling is done using the simple average
technique. For interpolated (bicubic) downsampling, use
nn.SpatialReSampling().

If the input image is a 3D tensor nInputPlane x height x width,
the output image size will be nInputPlane x oheight x owidth where

owidth  = floor(width/rW)
oheight = floor(height/rH) ]]

--- Construct a down-sampling module.
-- Arguments are parsed by xlua.unpack_class, which stores them on self as
-- self.rW (width ratio) and self.rH (height ratio); both are required.
-- Raises an error if either ratio is not strictly positive.
function SpatialDownSampling:__init(...)
   parent.__init(self)

   -- get args
   xlua.unpack_class(self, {...}, 'nn.SpatialDownSampling',  help_desc,
                     {arg='rW', type='number', help='ratio width', req=true},
                     {arg='rH', type='number', help='ratio height', req=true})

   -- Reject non-positive ratios up front: a zero ratio would otherwise reach
   -- updateOutput's resize as floor(size/0) = inf, and a negative one would
   -- be reported misleadingly as "input too small".
   if self.rW <= 0 or self.rH <= 0 then
      error('nn.SpatialDownSampling: ratios rW and rH must be positive')
   end
end

--- Forward pass: down-sample every plane of `input` by the ratios (rH, rW).
-- @param input 3D tensor laid out as nInputPlane x height x width
--   (dimension 2 is divided by rH, dimension 3 by rW).
-- @return self.output, resized to
--   nInputPlane x floor(height/rH) x floor(width/rW) and filled by the
--   C backend (input.nn.SpatialDownSampling_updateOutput).
function SpatialDownSampling:updateOutput(input)
   -- The resize below produces a 3D output, so only 3D input is meaningful;
   -- fail with a clear message instead of an obscure size()/backend error.
   if input:dim() ~= 3 then
      error('nn.SpatialDownSampling: expected a 3D input ' ..
            '(nInputPlane x height x width)')
   end

   local nPlane = input:size(1)
   local oheight = math.floor(input:size(2) / self.rH)
   local owidth  = math.floor(input:size(3) / self.rW)

   -- floor(x) < 1 holds exactly when x < 1, so these checks are equivalent
   -- to testing size/ratio < 1 directly.
   if oheight < 1 then
      error('input too small in dimension 2')
   elseif owidth < 1 then
      error('input too small in dimension 3')
   end

   self.output:resize(nPlane, oheight, owidth)
   input.nn.SpatialDownSampling_updateOutput(self, input)
   return self.output
end

--- Backward pass: compute the gradient w.r.t. the module input.
-- self.gradInput is first shaped like `input`; the C backend is then given
-- the module and gradOutput and is expected to fill self.gradInput.
-- NOTE(review): gradInput is resized but not zeroed here; presumably the
-- backend writes every element — confirm against the C implementation.
-- @param input the tensor passed to the matching forward call
-- @param gradOutput gradient w.r.t. self.output
-- @return self.gradInput
function SpatialDownSampling:updateGradInput(input, gradOutput)
   local gradInput = self.gradInput
   gradInput:resizeAs(input)

   local backend = input.nn
   backend.SpatialDownSampling_updateGradInput(self, gradOutput)

   return self.gradInput
end