blob: 9b599442f99879e9a30142d054dd4d396f90932b (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
-- nn.CMul: a module whose forward pass multiplies its input element-wise by a
-- learnable weight tensor (see updateOutput / accGradParameters below).
-- `parent` is the base class handed back by torch.class and is used in
-- __init to run the nn.Module constructor.
local CMul, parent = torch.class('nn.CMul', 'nn.Module')
--- Construct a CMul module.
-- Allocates the weight and its gradient accumulator with `inputSize`
-- elements, pre-sizes the output/gradInput buffers to the same size, and
-- initialises the weight via reset().
-- @param inputSize size passed to torch.Tensor / resize for every buffer
function CMul:__init(inputSize)
   parent.__init(self)

   local size = inputSize

   -- learnable parameters
   self.weight = torch.Tensor(size)
   self.gradWeight = torch.Tensor(size)

   -- state buffers, pre-sized to match the parameters
   self.output:resize(size)
   self.gradInput:resize(size)

   self:reset()
end
--- Re-initialise the weight so every element equals 1.
-- With an all-ones weight the forward pass returns the input unchanged.
function CMul:reset()
   local weight = self.weight
   weight:fill(1)
end
--- Forward pass: output = input .* weight (element-wise product).
-- NOTE(review): input presumably must have the same number of elements as
-- self.weight, since the product is taken against that tensor.
-- @param input tensor to scale
-- @return self.output, shaped like `input`
function CMul:updateOutput(input)
   -- copy() does not resize its destination, so size the buffer from the
   -- input instead of relying on the preallocation done in __init (which
   -- breaks if output was cleared, and gives a 1-D result for inputs of a
   -- different shape but equal element count).
   self.output:resizeAs(input):copy(input)
   self.output:cmul(self.weight)
   return self.output
end
--- Backward pass w.r.t. the input: gradInput = weight .* gradOutput.
-- Does nothing and returns nil when self.gradInput has been set to nil.
-- @param input the tensor given to the forward pass (used for sizing)
-- @param gradOutput gradient of the loss w.r.t. this module's output
-- @return self.gradInput, shaped like `input`, or nil if gradInput is unset
function CMul:updateGradInput(input, gradOutput)
   if self.gradInput then
      -- zero() does not resize, so make the buffer match the input before
      -- clearing it rather than trusting the preallocation in __init.
      self.gradInput:resizeAs(input):zero()
      self.gradInput:addcmul(1, self.weight, gradOutput)
      return self.gradInput
   end
end
--- Accumulate the weight gradient: gradWeight += scale * input .* gradOutput.
-- @param input the tensor given to the forward pass
-- @param gradOutput gradient of the loss w.r.t. this module's output
-- @param scale optional multiplier; any falsy value (nil/false) means 1
function CMul:accGradParameters(input, gradOutput, scale)
   local factor = scale
   if not factor then
      factor = 1
   end
   self.gradWeight:addcmul(factor, input, gradOutput)
end
|