Welcome to mirror list, hosted at ThFree Co, Russian Federation.

VolumetricMaxUnpooling.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 6291f5b858741ef198c8b19edb7d09709cd19c9d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
-- Registers the class 'nn.VolumetricMaxUnpooling' through torch.class, naming
-- 'nn.Module' as the second argument. The two returned values are kept as
-- file-level locals: `VolumetricMaxUnpooling` receives the methods defined
-- below, and `parent` is used by the constructor to call parent.__init.
local VolumetricMaxUnpooling, parent = torch.class('nn.VolumetricMaxUnpooling', 'nn.Module')

-- Constructor: ties this unpooling layer to an existing pooling layer.
--
-- poolingModule: must satisfy torch.type(...) == 'nn.VolumetricMaxPooling'
--   and have its kernel size equal to its stride on every axis
--   (kT == dT, kH == dH, kW == dW). Either check failing raises an
--   assertion error.
-- The module is stored by reference in self.pooling (no copy is made).
function VolumetricMaxUnpooling:__init(poolingModule)
  parent.__init(self)

  local moduleType = torch.type(poolingModule)
  assert(moduleType == 'nn.VolumetricMaxPooling',
         'Argument must be a nn.VolumetricMaxPooling module')

  local pool = poolingModule
  local kernelEqualsStride = pool.kT == pool.dT
    and pool.kH == pool.dH
    and pool.kW == pool.dW
  assert(kernelEqualsStride,
         "The size of pooling module's kernel must be equal to its stride")

  self.pooling = pool
end

-- Re-reads the parameters this module needs from the associated pooling
-- module (self.pooling) and stores them on self:
--   * indices                      -> self.indices (same object, not a copy)
--   * itime / iheight / iwidth     -> self.otime / self.oheight / self.owidth
--   * dT / dH / dW                 -> same names on self
--   * padT / padH / padW           -> same names on self
-- Takes no arguments and returns nothing.
-- NOTE(review): presumably the pooling module populates `indices` and the
-- i* sizes during its own forward pass -- confirm against
-- nn.VolumetricMaxPooling.
function VolumetricMaxUnpooling:setParams()
  local pool = self.pooling

  self.indices = pool.indices
  -- The pooling module's i* fields are stored under o* names here.
  self.otime, self.oheight, self.owidth = pool.itime, pool.iheight, pool.iwidth
  self.dT, self.dH, self.dW = pool.dT, pool.dH, pool.dW
  self.padT, self.padH, self.padW = pool.padT, pool.padH, pool.padW
end

-- Forward pass.
--
-- Refreshes the cached parameters via self:setParams(), then calls the
-- backend routine input.THNN.VolumetricMaxUnpooling_updateOutput with the
-- cdata of the input, of self.output and of self.indices, followed by the
-- size, stride and padding triples.
--
-- input: object exposing a THNN table and a :cdata() method
--   (presumably a torch tensor -- confirm against callers).
-- Returns self.output (presumably filled in by the backend call).
function VolumetricMaxUnpooling:updateOutput(input)
  self:setParams()

  local backend = input.THNN
  -- Each triple is passed as (time, width, height): width comes before
  -- height, unlike the T/H/W order used by the field names.
  backend.VolumetricMaxUnpooling_updateOutput(
    input:cdata(), self.output:cdata(), self.indices:cdata(),
    self.otime, self.owidth, self.oheight,
    self.dT,    self.dW,     self.dH,
    self.padT,  self.padW,   self.padH
  )

  return self.output
end

-- Backward pass with respect to the input.
--
-- Refreshes the cached parameters via self:setParams(), then calls the
-- backend routine input.THNN.VolumetricMaxUnpooling_updateGradInput with the
-- cdata of the input, of gradOutput, of self.gradInput and of self.indices,
-- followed by the size, stride and padding triples.
--
-- input:      object exposing a THNN table and a :cdata() method.
-- gradOutput: object exposing a :cdata() method
--   (presumably the gradient w.r.t. this module's output -- confirm).
-- Returns self.gradInput (presumably filled in by the backend call).
function VolumetricMaxUnpooling:updateGradInput(input, gradOutput)
  self:setParams()

  local backend = input.THNN
  -- Each triple is passed as (time, width, height): width comes before
  -- height, unlike the T/H/W order used by the field names.
  backend.VolumetricMaxUnpooling_updateGradInput(
    input:cdata(), gradOutput:cdata(), self.gradInput:cdata(), self.indices:cdata(),
    self.otime, self.owidth, self.oheight,
    self.dT,    self.dW,     self.dH,
    self.padT,  self.padW,   self.padH
  )

  return self.gradInput
end

-- Thin forwarder: calls self:clearState() and returns nothing.
-- NOTE(review): clearState is not defined in this file -- presumably it is
-- inherited from nn.Module; confirm.
function VolumetricMaxUnpooling:empty()
  self:clearState()
end

-- Builds the printable description of this module: a fixed prefix followed
-- by tostring() of the associated pooling module (self.pooling).
-- Returns the resulting string.
function VolumetricMaxUnpooling:__tostring__()
  local poolingDescription = tostring(self.pooling)
  return 'nn.VolumetricMaxUnpooling associated to ' .. poolingDescription
end