Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SpatialUpSampling.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 779e979dc8ff054ee27e4ac5a5ee66e83e25d497 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
-- nn.SpatialUpSampling: spatial up-sampling module, registered as a torch
-- class deriving from nn.Module (`parent` is the nn.Module class table).
local SpatialUpSampling, parent = torch.class('nn.SpatialUpSampling', 'nn.Module')

-- Help text handed to xlua.unpack_class in __init. It must describe the
-- layout the code actually uses: dim yDim is height (scaled by dH) and
-- dim xDim is width (scaled by dW); every other dim is preserved.
local help_desc = [[
Applies a 2D up-sampling over an input image composed of
several input planes. The input tensor in forward(input) must
hold the image height at dimension yDim (default 2) and the
width at dimension xDim (default 3), with xDim = yDim+1; with
the defaults this is a 3D tensor (nInputPlane x height x width).
All other dimensions are passed through unchanged, so the number
of output planes will be the same as nInputPlane.

The upsampling is done using the simple nearest neighbor
technique. For interpolated upsampling, use
nn.SpatialReSampling().

If the input image is a 3D tensor nInputPlane x height x width,
the output image size will be nInputPlane x oheight x owidth where

owidth  = width*dW
oheight = height*dH ]]

--- Constructor.
-- Parses the arguments (dW, dH required; yDim, xDim optional) onto self
-- via xlua.unpack_class, checks that the two spatial dims are adjacent,
-- and allocates the 4-element size storages reused by forward/backward.
function SpatialUpSampling:__init(...)
   parent.__init(self)

   local args = {...}
   xlua.unpack_class(
      self, args, 'nn.SpatialUpSampling', help_desc,
      {arg='dW', type='number', help='stride width', req=true},
      {arg='dH', type='number', help='stride height', req=true},
      {arg='yDim', type='number', help='image y dimension', default=2},
      {arg='xDim', type='number', help='image x dimension', default=3}
   )

   -- the x dimension has to come right after the y dimension
   if self.xDim ~= self.yDim + 1 then
      error('nn.SpatialUpSampling: yDim must be equals to xDim-1')
   end

   -- scratch shapes (lead, height, width, trail), refilled on every forward
   self.inputSize = torch.LongStorage(4)
   self.outputSize = torch.LongStorage(4)
end

--- Forward pass.
-- Every dim in front of yDim is folded into one leading size, and every
-- dim behind xDim into one trailing size. The backend kernel
-- `input.nn.SpatialUpSampling_updateOutput` therefore always receives a
-- 4D (lead x height x width x trail) input. The result is then reshaped
-- back to the caller's layout, with only dims yDim/xDim scaled by dH/dW.
-- @param input tensor with at least xDim dimensions
-- @return self.output
function SpatialUpSampling:updateOutput(input)
   local yDim, xDim = self.yDim, self.xDim
   local dH, dW = self.dH, self.dW

   -- collapse the dims before the y dimension
   local lead = 1
   for d = 1, yDim - 1 do
      lead = lead * input:size(d)
   end
   local height = input:size(yDim)
   local width = input:size(xDim)
   -- collapse the dims after the x dimension
   local trail = 1
   for d = xDim + 1, input:nDimension() do
      trail = trail * input:size(d)
   end

   local inSize, outSize = self.inputSize, self.outputSize
   inSize[1], inSize[2], inSize[3], inSize[4] = lead, height, width, trail
   outSize[1], outSize[2], outSize[3], outSize[4] =
      lead, height * dH, width * dW, trail

   self.output:resize(outSize)
   input.nn.SpatialUpSampling_updateOutput(self, input:reshape(inSize))

   -- restore the caller's layout; only the two spatial dims grow
   local finalSize = input:size()
   finalSize[yDim] = finalSize[yDim] * dH
   finalSize[xDim] = finalSize[xDim] * dW
   self.output = self.output:reshape(finalSize)
   return self.output
end

--- Backward pass.
-- Fills self.gradInput from gradOutput using the backend kernel
-- `input.nn.SpatialUpSampling_updateGradInput`, then reshapes it to
-- input's shape.
-- Relies on self.inputSize / self.outputSize having been computed by the
-- preceding updateOutput call on this same input.
-- @param input the tensor given to the matching updateOutput
-- @param gradOutput gradient w.r.t. self.output
-- @return self.gradInput
function SpatialUpSampling:updateGradInput(input, gradOutput)
   self.gradInput:resize(self.inputSize)
   -- Give the kernel every tensor in the same collapsed 4D layout that
   -- updateOutput uses. The raw input can have a different rank and leading
   -- sizes (e.g. when yDim > 2), which would then disagree with gradInput
   -- and gradOutput.
   input.nn.SpatialUpSampling_updateGradInput(self,
                                              input:reshape(self.inputSize),
                                              gradOutput:reshape(self.outputSize))
   -- back to the caller's original input layout
   self.gradInput = self.gradInput:reshape(input:size())
   return self.gradInput
end