1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
|
local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Module')
function ConcatTable:__init()
parent.__init(self)
self.modules = {}
self.output = {}
end
function ConcatTable:add(module)
table.insert(self.modules, module)
return self
end
function ConcatTable:get(index)
return self.modules[index]
end
function ConcatTable:size()
return #self.modules
end
function ConcatTable:updateOutput(input)
for i=1,#self.modules do
self.output[i] = self.modules[i]:updateOutput(input)
end
return self.output
end
function ConcatTable:updateGradInput(input, gradOutput)
for i,module in ipairs(self.modules) do
local currentGradInput = module:updateGradInput(input, gradOutput[i])
if i == 1 then
self.gradInput:resizeAs(currentGradInput):copy(currentGradInput)
else
self.gradInput:add(currentGradInput)
end
end
return self.gradInput
end
function ConcatTable:accGradParameters(input, gradOutput, scale)
scale = scale or 1
for i,module in ipairs(self.modules) do
module:accGradParameters(input, gradOutput[i], scale)
end
end
function ConcatTable:accUpdateGradParameters(input, gradOutput, lr)
for i,module in ipairs(self.modules) do
module:accUpdateGradParameters(input, gradOutput[i], lr)
end
end
function ConcatTable:zeroGradParameters()
for _,module in ipairs(self.modules) do
module:zeroGradParameters()
end
end
function ConcatTable:updateParameters(learningRate)
for _,module in ipairs(self.modules) do
module:updateParameters(learningRate)
end
end
function ConcatTable:share(mlp,...)
for i=1,#self.modules do
self.modules[i]:share(mlp.modules[i],...);
end
end
|