-- CMul.lua, from the torch/nn package (github.com/torch/nn)

local CMul, parent = torch.class('nn.CMul', 'nn.Module')
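
-- CMul applies a learned component-wise scaling to its input:
--   output = weight .* input
-- The weight tensor has the shape given at construction and, when its size
-- differs from the input's (e.g. a missing batch dimension), it is expanded
-- (broadcast) to match the input in updateOutput.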

function CMul:__init(...)
   parent.__init(self)

   local arg = {...}

   -- The weight size may be given either as a single torch.LongStorage
   -- or as a list of dimension sizes.
   self.size = torch.LongStorage()
   local n = #arg
   if n == 1 and torch.type(arg[1]) == 'torch.LongStorage' then
      self.size:resize(#arg[1]):copy(arg[1])
   else
      self.size:resize(n)
      for i=1,n do
         self.size[i] = arg[i]
      end
   end

   self.weight = torch.Tensor(self.size)
   self.gradWeight = torch.Tensor(self.size)

   self.output:resize(self.size)

   self:reset()
end
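
-- Usage sketch (hypothetical sizes; assumes `torch` and `nn` are loaded):
--   local m = nn.CMul(1, 5, 1)                 -- one scale per feature map
--   local y = m:forward(torch.randn(4, 5, 3))  -- weight expanded to 4x5x3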
 
-- Reset weights to a uniform distribution on [-stdv, stdv]. A user-supplied
-- stdv is scaled by sqrt(3) so the samples have standard deviation stdv
-- (a uniform on [-a, a] has variance a^2/3); by default stdv is
-- 1/sqrt(#weights).
function CMul:reset(stdv)
   if stdv then
      stdv = stdv * math.sqrt(3)
   else
      stdv = 1./math.sqrt(self.weight:nElement())
   end
   self.weight:uniform(-stdv,stdv)
end

function CMul:updateOutput(input)
   -- lazily allocate temporary buffers with the same tensor type as input
   self._output = self._output or input.new()
   self._weight = self._weight or input.new()
   self._expand = self._expand or input.new()
   self._repeat = self._repeat or input.new()

   self.output:resizeAs(input):copy(input)
   if input:nElement() == self.weight:nElement() then
      -- same number of elements: plain element-wise multiply on flat views
      self._output:view(self.output, -1)
      self._weight:view(self.weight, -1)

      self._output:cmul(self._weight)
   else
      -- sizes differ: broadcast the weight against the input
      if self.weight:dim() == input:dim() then
         -- same dimensionality; expand the singleton dimensions directly
         self._output:set(self.output)
         self._weight:set(self.weight)
      else
         -- treat dimension 1 of the input as a batch dimension and
         -- broadcast the flattened weight over it
         local batchSize = input:size(1)
         self._output:view(self.output, batchSize, -1)
         self._weight:view(self.weight, 1, -1)
      end

      self._expand:expandAs(self._weight, self._output)

      if torch.type(input) == 'torch.CudaTensor' then
         -- on GPU, materialize the expansion first (presumably because the
         -- expanded view has zero strides, which the CUDA path here does
         -- not cope with)
         self._repeat:resizeAs(self._expand):copy(self._expand)
         self._output:cmul(self._repeat)
      else
         self._output:cmul(self._expand)
      end
   end

   return self.output
end
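
-- Backward pass. Since output = weight .* input (with broadcasting),
--   dL/dinput  = weight .* gradOutput, expanded as in the forward pass, and
--   dL/dweight = input .* gradOutput, summed over the expanded dimensions.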

function CMul:updateGradInput(input, gradOutput)
   if not self.gradInput then
      return
   end

   self._gradOutput = self._gradOutput or input.new()
   self._gradInput = self._gradInput or input.new()

   self.gradInput:resizeAs(input):zero()
   if self.weight:nElement() == gradOutput:nElement() then
      -- no broadcasting: gradInput = weight .* gradOutput
      self.gradInput:addcmul(1, self.weight, gradOutput)
   else
      -- mirror the view/expand logic of updateOutput
      if self.weight:dim() == input:dim() then
         nn.utils.contiguousView(self._gradOutput, gradOutput, gradOutput:size())
         nn.utils.contiguousView(self._gradInput, self.gradInput, self.gradInput:size())
         self._weight:set(self.weight)
      else
         local batchSize = input:size(1)
         nn.utils.contiguousView(self._gradOutput, gradOutput, batchSize, -1)
         nn.utils.contiguousView(self._gradInput, self.gradInput, batchSize, -1)
         self._weight:view(self.weight, 1, -1)
      end

      self._expand:expandAs(self._weight, self._gradOutput)

      if torch.type(input) == 'torch.CudaTensor' then
         -- materialize the expansion on GPU, as in updateOutput
         self._repeat:resizeAs(self._expand):copy(self._expand)
         self._gradInput:addcmul(1, self._repeat, self._gradOutput)
      else
         self._gradInput:addcmul(1, self._expand, self._gradOutput)
      end
   end

   return self.gradInput
end

function CMul:accGradParameters(input, gradOutput, scale)
   scale = scale or 1

   self._input = self._input or input.new()
   self._gradWeight = self._gradWeight or input.new()
   self._sum = self._sum or input.new()

   if self.weight:nElement() == gradOutput:nElement() then
      -- no broadcasting: gradWeight += scale * input .* gradOutput
      self.gradWeight:addcmul(scale, input, gradOutput)
   else
      if self.weight:dim() == input:dim() then
         nn.utils.contiguousView(self._input, input, input:size())
         nn.utils.contiguousView(self._gradOutput, gradOutput, gradOutput:size())
         self._gradWeight:set(self.gradWeight)

         self._repeat:cmul(self._input, self._gradOutput)
         -- sum input .* gradOutput over every dimension the weight does not
         -- cover, ping-ponging between the two buffers so that each pass
         -- reads the previous pass's result
         local sumInto = self._sum
         local sumFrom = self._repeat
         for i=1,self.weight:dim() do
            if self.weight:size(i) ~= input:size(i) then
               sumInto:sum(sumFrom, i)
               sumInto = sumFrom
               sumFrom = sumFrom == self._repeat and self._sum or self._repeat
            end
         end
         -- sumFrom now points at the buffer holding the final reduction
         self._gradWeight:add(scale, sumFrom)
      else
         -- batch mode: reduce over the batch dimension only
         local batchSize = input:size(1)
         nn.utils.contiguousView(self._input, input, batchSize, -1)
         nn.utils.contiguousView(self._gradOutput, gradOutput, batchSize, -1)
         self._gradWeight:view(self.gradWeight, 1, -1)

         self._repeat:cmul(self._input, self._gradOutput)
         self._sum:sum(self._repeat, 1)
         self._gradWeight:add(scale, self._sum)
      end
   end
end
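
-- Gradient-check sketch (hypothetical sizes; nn.Jacobian ships with nn and
-- compares the analytic Jacobian against finite differences):
--   local m = nn.CMul(1, 5, 1)
--   local input = torch.randn(4, 5, 3)
--   print(nn.Jacobian.testJacobian(m, input))       -- d(output)/d(input)
--   print(nn.Jacobian.testJacobianParameters(m, input,
--         m.weight, m.gradWeight))                  -- d(output)/d(weight)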

function CMul:type(type, tensorCache)
   if type then
      -- drop the cached buffers so they are re-created with the new type
      self:clearState()
   end
   return parent.type(self, type, tensorCache)
end

function CMul:clearState()
   nn.utils.clear(self, {
      '_input',
      '_output',
      '_weight',
      '_gradWeight',
      '_expand',
      '_repeat',
      '_sum',
   })
   return parent.clearState(self)
end
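
-- clearState drops only the temporary buffers above, so it is safe to call
-- before serializing to keep the saved file small (hypothetical filename):
--   model:clearState()
--   torch.save('cmul.t7', model)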