Welcome to mirror list, hosted at ThFree Co, Russian Federation.

CMul.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: ea3485a867f26396cfef7b36de1a93fdc7882087 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
-- nn.CMul: component-wise multiplication of the input by a learnable
-- weight tensor whose shape is fixed at construction time.
local CMul, parent = torch.class('nn.CMul', 'nn.Module')

--- Construct a CMul module.
-- The weight shape is given either as a single torch.LongStorage,
-- e.g. nn.CMul(torch.LongStorage{3, 4}), or as individual sizes,
-- e.g. nn.CMul(3, 4). weight and gradWeight are allocated with that
-- shape, the output buffer is pre-sized to it, and reset() is called
-- to initialize the weight.
function CMul:__init(...)
   parent.__init(self)

   local dims = {...}
   local count = #dims

   local shape = torch.LongStorage()
   self.size = shape

   local only = dims[1]
   if count == 1 and torch.type(only) == 'torch.LongStorage' then
      -- copy the sizes into our own storage instead of aliasing the caller's
      shape:resize(#only):copy(only)
   else
      shape:resize(count)
      for d = 1, count do
         shape[d] = dims[d]
      end
   end

   self.weight = torch.Tensor(shape)
   self.gradWeight = torch.Tensor(shape)

   -- pre-size the output buffer to the weight's shape
   self.output:resize(shape)

   self:reset()
end
 
--- Re-initialize the weight with values drawn uniformly at random.
-- @param stdv optional. When given, samples come from
--   U(-stdv*sqrt(3), stdv*sqrt(3)), a uniform distribution whose standard
--   deviation is stdv. When omitted, the bound is 1/sqrt(#weight elements).
function CMul:reset(stdv)
   local bound
   if not stdv then
      bound = 1./math.sqrt(self.weight:nElement())
   else
      bound = stdv * math.sqrt(3)
   end
   self.weight:uniform(-bound, bound)
end

--- Forward pass: output = input .* weight (element-wise).
-- If input and weight hold the same number of elements, both are flattened
-- and multiplied element-wise. Otherwise the output is viewed as
-- (batchSize x rest) with batchSize = input:size(1), and the flattened
-- weight (viewed as a single row) is expanded across the batch rows.
-- @return self.output
function CMul:updateOutput(input)
   -- scratch buffers are created lazily with the input's tensor type
   self._output = self._output or input.new()
   self._weight = self._weight or input.new()
   self._expand = self._expand or input.new()
   self._repeat = self._repeat or input.new()

   local out = self.output
   out:resizeAs(input):copy(input)

   if input:nElement() == self.weight:nElement() then
      -- same element count: multiply the flattened views in place
      self._weight:view(self.weight, -1)
      self._output:view(out, -1)
      self._output:cmul(self._weight)
      return out
   end

   -- batch mode: each row of _output is one sample, weight is one row
   local nBatch = input:size(1)
   self._weight:view(self.weight, 1, -1)
   self._output:view(out, nBatch, -1)
   self._expand:expandAs(self._weight, self._output)

   local factor = self._expand
   if torch.type(input) == 'torch.CudaTensor' then
      -- Materialize the expanded weight into a real tensor first.
      -- NOTE(review): presumably the CUDA cmul did not accept expanded
      -- (zero-stride) operands at the time -- confirm before removing.
      self._repeat:resizeAs(factor):copy(factor)
      factor = self._repeat
   end
   self._output:cmul(factor)

   return out
end

--- Backward pass w.r.t. the input: gradInput = gradOutput .* weight.
-- Does nothing (returns nil) when self.gradInput has been set to nil.
-- Mirrors updateOutput's two modes: when weight and gradOutput have the
-- same number of elements they are multiplied directly; otherwise
-- gradOutput is viewed as (batchSize x rest), with batchSize =
-- input:size(1), and the weight row is expanded over the batch.
-- @return self.gradInput
function CMul:updateGradInput(input, gradOutput)
   if not self.gradInput then
      return
   end

   -- Create every buffer this method touches. _weight, _expand and _repeat
   -- are also created by updateOutput, but :type() resets them to nil, so
   -- they cannot be assumed to exist here.
   self._gradOutput = self._gradOutput or input.new()
   self._gradInput = self._gradInput or input.new()
   self._weight = self._weight or input.new()
   self._expand = self._expand or input.new()
   self._repeat = self._repeat or input.new()

   self.gradInput:resizeAs(input):zero()
   if self.weight:nElement() == gradOutput:nElement() then
      self.gradInput:addcmul(1, self.weight, gradOutput)
   else
      local batchSize = input:size(1)
      -- view() requires a contiguous source; contiguous() returns gradOutput
      -- itself when it already is, otherwise a compact copy.
      self._gradOutput:view(gradOutput:contiguous(), batchSize, -1)
      self._gradInput:view(self.gradInput, batchSize, -1)
      self._weight:view(self.weight, 1, -1)
      self._expand:expandAs(self._weight, self._gradOutput)

      if torch.type(input) == 'torch.CudaTensor' then
         -- same materialization of the expanded weight as in updateOutput
         self._repeat:resizeAs(self._expand):copy(self._expand)
         self._gradInput:addcmul(1, self._repeat, self._gradOutput)
      else
         self._gradInput:addcmul(1, self._expand, self._gradOutput)
      end
   end

   return self.gradInput
end

--- Accumulate the weight gradient: gradWeight += scale * (input .* gradOutput).
-- When weight and gradOutput have different element counts, input and
-- gradOutput are viewed as (batchSize x rest), with batchSize =
-- input:size(1), and the per-sample products are summed over dim 1
-- before being added.
-- @param scale optional multiplier, defaults to 1
function CMul:accGradParameters(input, gradOutput, scale)
   scale = scale or 1

   -- Create every buffer used below. _gradOutput is otherwise only created
   -- by updateGradInput, which returns before creating it when
   -- self.gradInput is nil. _repeat is otherwise only created by
   -- updateOutput and is reset to nil by :type().
   self._input = self._input or input.new()
   self._gradOutput = self._gradOutput or input.new()
   self._gradWeight = self._gradWeight or input.new()
   self._repeat = self._repeat or input.new()
   self._sum = self._sum or input.new()

   if self.weight:nElement() == gradOutput:nElement() then
      self.gradWeight:addcmul(scale, input, gradOutput)
   else
      local batchSize = input:size(1)
      -- view() requires contiguous sources; contiguous() returns the tensor
      -- itself when it already is, otherwise a compact copy.
      self._input:view(input:contiguous(), batchSize, -1)
      self._gradOutput:view(gradOutput:contiguous(), batchSize, -1)
      self._gradWeight:view(self.gradWeight, 1, -1)

      -- per-sample products, summed over the batch dimension
      self._repeat:cmul(self._input, self._gradOutput)
      self._sum:sum(self._repeat, 1)
      self._gradWeight:add(scale, self._sum)
   end
end

--- Convert the module to another tensor type (e.g. 'torch.CudaTensor').
-- When a target type is given, the lazily-created scratch buffers are
-- dropped first, so they are re-created with the new type on next use
-- instead of being converted. Any extra arguments, such as the tensorCache
-- that newer nn.Module:type accepts to preserve shared storages, are
-- forwarded unchanged to the parent implementation.
-- @return whatever parent.type returns
function CMul:type(type, ...)
   if type then
      self._input = nil
      self._output = nil
      self._weight = nil
      self._gradWeight = nil
      self._expand = nil
      self._repeat = nil
      self._sum = nil
   end
   return parent.type(self, type, ...)
end