Welcome to mirror list, hosted at ThFree Co, Russian Federation.

Container.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 484a3be8e391bfb04457d79979fa0e033a3056a3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
-- Code common to container modules (e.g. Parallel, Sequential): modules that
-- hold an ordered collection of smaller constituent modules and forward
-- operations to each of them.
--
-- Declares the class 'nn.Container' with 'nn.Module' as its parent class.
-- `Container` is the table the methods below are defined on; `parent` is used
-- to reach the parent class implementation (see Container:__init).
local Container, parent = torch.class('nn.Container', 'nn.Module')

-- Constructor. Starts the container with an empty list of child modules.
--
-- All constructor arguments are forwarded unchanged to the parent class
-- initializer.
function Container:__init(...)
    -- Ordered list of child modules; Container:add() appends to it.
    local children = {}
    parent.__init(self, ...)
    self.modules = children
end

-- Append `module` to the end of this container's list of child modules.
--
-- Returns the container itself so that calls can be chained:
--   container:add(a):add(b)
function Container:add(module)
    local children = self.modules
    children[#children + 1] = module
    return self
end

-- Return the child module stored at position `index` (1-based), or nil when
-- nothing is stored at that position.
function Container:get(index)
    local children = self.modules
    return children[index]
end

-- Return the number of child modules held by this container.
function Container:size()
    local children = self.modules
    return #children
end

-- Call zeroGradParameters() on every child module, in insertion order.
function Container:zeroGradParameters()
    local children = self.modules
    for idx = 1, #children do
        local child = children[idx]
        child:zeroGradParameters()
    end
end

-- Call updateParameters(learningRate) on every child module, in insertion
-- order. `learningRate` is passed through to each child unchanged.
function Container:updateParameters(learningRate)
    local children = self.modules
    for _, child in ipairs(children) do
        child:updateParameters(learningRate)
    end
end

-- Call training() on every child module, in insertion order.
-- Only the children are touched; nothing is set on the container itself.
function Container:training()
    local children = self.modules
    for idx = 1, #children do
        local child = children[idx]
        child:training()
    end
end

-- Call evaluate() on every child module, in insertion order.
-- Only the children are touched; nothing is set on the container itself.
function Container:evaluate()
    local children = self.modules
    for idx = 1, #children do
        local child = children[idx]
        child:evaluate()
    end
end

-- Forward share() to every child module, pairing the i-th child of this
-- container with the i-th child of `mlp`.
--
-- Parameters:
--   mlp  another container; its `modules` list is indexed position by
--        position. Only as many positions as this container has children
--        are visited, so `mlp.modules[i]` is nil where `mlp` is shorter.
--   ...  extra arguments, forwarded unchanged to each child's share().
--
-- Returns the container itself, so the call can be chained or used as an
-- expression, consistent with Container:add(). Previously nothing was
-- returned; callers that ignored the result are unaffected.
function Container:share(mlp, ...)
    for i = 1, #self.modules do
        self.modules[i]:share(mlp.modules[i], ...)
    end
    return self
end

-- Call reset(stdv) on every child module, in insertion order.
-- `stdv` is passed through to each child unchanged.
function Container:reset(stdv)
    local children = self.modules
    for idx = 1, #children do
        local child = children[idx]
        child:reset(stdv)
    end
end

-- Collect the results of parameters() from every child module into two flat
-- lists.
--
-- Each child's parameters() is expected to return two values; when the first
-- one is nil/false the child is skipped. Otherwise both values are flattened:
-- nested Lua tables are walked depth-first in index order, and every
-- non-table value is appended as a single entry.
--
-- Returns two tables: the flattened first results and the flattened second
-- results, in child order. Both are empty tables (never nil) when no child
-- contributes anything.
function Container:parameters()
    -- Append `src` to `dst`, recursing into plain Lua tables.
    local function flattenInto(dst, src)
        if type(src) ~= 'table' then
            table.insert(dst, src)
            return
        end
        for k = 1, #src do
            flattenInto(dst, src[k])
        end
    end

    local weights, gradWeights = {}, {}
    local children = self.modules
    for idx = 1, #children do
        local childWeights, childGradWeights = children[idx]:parameters()
        if childWeights then
            flattenInto(weights, childWeights)
            flattenInto(gradWeights, childGradWeights)
        end
    end
    return weights, gradWeights
end