Welcome to mirror list, hosted at ThFree Co, Russian Federation.

CAdd.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 1d7b45726cb15814f109e69b0365ad6edd09ec31 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
-- nn.CAdd: adds a learnable bias tensor (self.bias) to its input, expanding the
-- bias to the input's shape when their sizes differ. Derives from nn.Module;
-- `parent` is kept so methods below can call the base-class implementations.
local CAdd, parent = torch.class("nn.CAdd", "nn.Module")

-- Constructor. The bias shape is given either as a single torch.LongStorage or
-- as one number per dimension, e.g. nn.CAdd(5) or nn.CAdd(1, 3, 4).
-- Allocates self.bias / self.gradBias with that shape, sizes self.output to
-- match, and initialises the bias through reset().
function CAdd:__init(...)
   parent.__init(self)

   local dims = {...}
   local ndims = #dims

   self.size = torch.LongStorage()
   if ndims == 1 and torch.type(dims[1]) == 'torch.LongStorage' then
      -- The whole shape arrives in one storage: copy it over.
      local shape = dims[1]
      self.size:resize(#shape):copy(shape)
   else
      -- Each argument is the extent of one dimension.
      self.size:resize(ndims)
      for d = 1, ndims do
         self.size[d] = dims[d]
      end
   end

   self.bias = torch.Tensor(self.size)
   self.gradBias = torch.Tensor(self.size)
   self.output:resize(self.size)

   self:reset()
end

-- Re-initialises self.bias with values drawn uniformly from [-a, a].
-- If `stdv` is given, it is the desired standard deviation: a uniform
-- distribution on [-a, a] has std a/sqrt(3), hence a = stdv * sqrt(3).
-- Otherwise a = 1/sqrt(number of bias elements).
function CAdd:reset(stdv)
   local bound
   if not stdv then
      bound = 1.0/math.sqrt(self.bias:nElement())
   else
      bound = stdv * math.sqrt(3)
   end
   self.bias:uniform(-bound, bound)
end

-- Forward pass: output = input + bias.
-- * If input and bias hold the same number of elements, the bias is added
--   element-wise.
-- * Otherwise the bias is expanded (stride-0 view) to the input's shape:
--   directly when both have the same number of dimensions, or, when the
--   dimension counts differ, after viewing the output as
--   (input:size(1) x rest) and the bias as (1 x rest).
-- Returns self.output.
function CAdd:updateOutput(input)
   -- Scratch tensors are created lazily with the same tensor type as input.
   self._output = self._output or input.new()
   self._bias = self._bias or input.new()
   self._expand = self._expand or input.new()
   self._repeat = self._repeat or input.new()

   local output = self.output
   output:resizeAs(input):copy(input)

   if input:nElement() == self.bias:nElement() then
      output:add(self.bias)
      return self.output
   end

   if self.bias:dim() ~= input:dim() then
      -- Dimension counts differ: keep the first input dimension, flatten the
      -- rest, and give the bias a leading singleton dimension.
      self._output:view(output, input:size(1), -1)
      self._bias:view(self.bias, 1, -1)
   else
      self._output:set(output)
      self._bias:set(self.bias)
   end

   local expanded = self._expand
   expanded:expandAs(self._bias, self._output)

   if torch.type(input) ~= 'torch.CudaTensor' then
      self._output:add(expanded)
   else
      -- expandAs yields a non-contiguous stride-0 view and cuda ops may
      -- assume contiguous input, so materialise the expansion first.
      local dense = self._repeat
      dense:resizeAs(expanded):copy(expanded)
      self._output:add(dense)
   end

   return self.output
end

-- Backward pass w.r.t. the input. Adding a bias does not change how the
-- output depends on the input, so gradInput is a copy of gradOutput.
-- Returns self.gradInput.
function CAdd:updateGradInput(input, gradOutput)
   if not self.gradInput then
      self.gradInput = input.new()
   end
   local gradInput = self.gradInput
   gradInput:resizeAs(gradOutput)
   gradInput:copy(gradOutput)
   return gradInput
end

-- Accumulates scale * dLoss/dBias into self.gradBias.
-- * If gradOutput has as many elements as the bias, it is added element-wise.
-- * Otherwise the bias was expanded in updateOutput. Each bias entry then
--   receives the SUM of gradOutput over every dimension along which it was
--   expanded. The views mirror updateOutput: same-rank tensors are used
--   as-is; otherwise gradOutput is viewed as (input:size(1) x rest) and the
--   bias as (1 x rest).
-- `scale` defaults to 1.
--
-- The reduction is done with an explicit sum. Accumulating into a stride-0
-- expanded view is unsafe: the previous CUDA path copied a dense buffer back
-- into such a view, so only one slice of gradOutput survived instead of the
-- sum.
function CAdd:accGradParameters(input, gradOutput, scale)
   scale = scale or 1

   self._gradBias = self._gradBias or gradOutput.new()
   self._gradOutput = self._gradOutput or gradOutput.new()

   if self.bias:nElement() == gradOutput:nElement() then
      self.gradBias:add(scale, gradOutput)
      return
   end

   if self.bias:dim() == gradOutput:dim() then
      self._gradBias:set(self.gradBias)
      self._gradOutput:set(gradOutput)
   else
      local batchSize = input:size(1)
      self._gradBias:view(self.gradBias, 1, -1)
      -- view() requires contiguous memory; contiguous() returns the tensor
      -- itself when it already is.
      self._gradOutput:view(gradOutput:contiguous(), batchSize, -1)
   end

   -- Sum over each dimension where the bias view has extent 1 but gradOutput
   -- does not. torch7's sum(dim) keeps the reduced dimension (size 1), so
   -- the shapes stay aligned with the bias view.
   local reduced = self._gradOutput
   for d = 1, reduced:dim() do
      if self._gradBias:size(d) == 1 and reduced:size(d) ~= 1 then
         reduced = reduced:sum(d)
      end
   end

   -- Keep failing on incompatible shapes, as expandAs did before, rather than
   -- silently adding tensors that merely share an element count.
   assert(reduced:isSameSizeAs(self._gradBias),
          'CAdd: gradOutput size is not compatible with bias size')

   -- _gradBias shares storage with self.gradBias, so this accumulates in place.
   self._gradBias:add(scale, reduced)
end

-- Type conversion, delegated to parent.type (nn.Module). When a target type is
-- supplied, the scratch buffers are dropped via clearState first, so stale
-- buffers of the old type are not converted along with the parameters.
function CAdd:type(type, tensorCache)
   if not type then
      return parent.type(self, type, tensorCache)
   end
   self:clearState()
   return parent.type(self, type, tensorCache)
end

-- Releases every scratch buffer created by updateOutput / accGradParameters,
-- then lets the parent class clear its own state. '_gradOutput' is included:
-- accGradParameters points it at the caller's gradOutput storage, and leaving
-- it set would keep that memory alive (and serialise it with the module).
function CAdd:clearState()
   nn.utils.clear(self, {
      '_gradBias',
      '_gradOutput',
      '_expand',
      '_output',
      '_bias',
      '_repeat'
   })
   return parent.clearState(self)
end