Welcome to mirror list, hosted at ThFree Co, Russian Federation.

CMul.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 9b599442f99879e9a30142d054dd4d396f90932b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
-- nn.CMul: a module whose forward pass multiplies its input element-wise by a
-- learnable weight tensor. `parent` is the base class handed back by
-- torch.class; __init chains to it to set up the common Module state.
local CMul, parent = torch.class('nn.CMul', 'nn.Module')

--- Construct a CMul module.
-- Allocates a weight tensor and a matching gradient accumulator, both of
-- size `inputSize`, pre-sizes the output/gradInput buffers to the same
-- size, and initializes the weight via reset().
-- @param inputSize size passed to torch.Tensor for the weight
--        (NOTE(review): presumably a number or LongStorage — confirm with callers)
function CMul:__init(inputSize)
   parent.__init(self)

   -- pre-size the state buffers so they match the parameter shape
   self.output:resize(inputSize)
   self.gradInput:resize(inputSize)

   -- learnable per-element scale and the buffer its gradients accumulate into
   local weight = torch.Tensor(inputSize)
   local gradWeight = torch.Tensor(inputSize)
   self.weight = weight
   self.gradWeight = gradWeight

   self:reset()
end
 
--- Re-initialize the weight so every element is 1.
-- After a reset, the forward pass returns its input unchanged.
function CMul:reset()
   local w = self.weight
   w:fill(1)
end

--- Forward pass: output = input .* weight (element-wise).
-- @param input tensor with the same number of elements as `self.weight`
-- @return self.output, resized to the shape of `input`
function CMul:updateOutput(input)
   -- Tensor:copy() never resizes its destination, so resize explicitly. Without
   -- this, the result keeps the shape set in __init and fails if the buffer
   -- was ever reset/emptied.
   self.output:resizeAs(input):copy(input)
   self.output:cmul(self.weight)
   return self.output
end

--- Backward pass w.r.t. the input.
-- Since output = input .* weight, d(output)/d(input) = weight, so
-- gradInput = weight .* gradOutput.
-- @param input the tensor given to the forward pass (used for sizing)
-- @param gradOutput gradient of the loss w.r.t. this module's output
-- @return self.gradInput, or nil when the module has no gradInput buffer
function CMul:updateGradInput(input, gradOutput)
   if self.gradInput then
      -- Size the buffer from the input instead of relying on the size set in
      -- __init, mirroring updateOutput.
      self.gradInput:resizeAs(input):zero()
      self.gradInput:addcmul(1, self.weight, gradOutput)
      return self.gradInput
   end
end

--- Accumulate the weight gradient.
-- Since output = input .* weight, d(output)/d(weight) = input, so this adds
-- scale * input .* gradOutput into self.gradWeight (it does not overwrite).
-- @param input the tensor given to the forward pass
-- @param gradOutput gradient of the loss w.r.t. this module's output
-- @param scale optional multiplier for the accumulated gradient (defaults to 1)
function CMul:accGradParameters(input, gradOutput, scale)
   local factor = scale or 1
   local acc = self.gradWeight
   acc:addcmul(factor, input, gradOutput)
end