diff options
author | Andreas Köpf <andreas.koepf@xamla.com> | 2015-09-15 20:31:09 +0300 |
---|---|---|
committer | Andreas Köpf <andreas.koepf@xamla.com> | 2015-10-16 14:20:57 +0300 |
commit | 377afbc9be15ae98135bca2924bc65b72710348f (patch) | |
tree | 43686155ee73dcd9a88724ee842b6626a246c64c /gmodule.lua | |
parent | 7a35aecd72082dba6f7f8a1b3ec309840360d81a (diff) |
Use nn.Container as base class for gModule
Added modules to container in ctor, removed redundant methods training(),
evaluate(), share(), zeroGradParameters(), parameters(), clone() which are
now provided by the base classes (nn.gModule -> nn.Container -> nn.Module).
Diffstat (limited to 'gmodule.lua')
-rw-r--r-- | gmodule.lua | 98 |
1 file changed, 15 insertions, 83 deletions
diff --git a/gmodule.lua b/gmodule.lua index d6e17d7..a5c7d91 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -1,4 +1,3 @@ - local nesting = paths.dofile('nesting.lua') local utils = paths.dofile('utils.lua') local istensor = torch.isTensor @@ -40,7 +39,7 @@ end -- -- The node.data.gradOutput holds the to-be-summed gradOutputs. -- Each node has only one output. So we need only one gradOutput. -local gModule, parent = torch.class('nn.gModule','nn.Module') +local gModule, parent = torch.class('nn.gModule','nn.Container') function gModule:__init(inputs,outputs) parent.__init(self) @@ -109,7 +108,6 @@ function gModule:__init(inputs,outputs) self.innode = self.fg:roots()[1] end - assert(self.innode.data == innode.data, "expecting the forward innode") self.outnode = outnode self.verbose = false @@ -118,24 +116,23 @@ function gModule:__init(inputs,outputs) -- computation on the graph is done through topsort of forward and backward graphs self.forwardnodes = self.fg:topsort() self.backwardnodes = self.bg:topsort() - self.modules = {} - for _, node in ipairs(self.forwardnodes) do - if node.data.module then - table.insert(self.modules, node.data.module) + + -- iteratare over all nodes: check, tag and add to container + for i,node in ipairs(self.forwardnodes) do + -- check for unused inputs or unused split() outputs + if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then + local nUnused = node.data.nSplitOutputs - #node.children + error(string.format("%s of split(%s) outputs are unused", nUnused, node.data.nSplitOutputs)) end - end - -- Checking for unused inputs or unused split() outputs. 
- for i,forwardNode in ipairs(self.forwardnodes) do - if forwardNode.data.nSplitOutputs and forwardNode.data.nSplitOutputs ~= #forwardNode.children then - local nUnused = forwardNode.data.nSplitOutputs - #forwardNode.children - error(string.format("%s of split(%s) outputs are unused", nUnused, - forwardNode.data.nSplitOutputs)) + + -- set data.forwardNodeId for node:label() output + node.data.forwardNodeId = node.id + + -- add module to container + if node.data.module then + self:add(node.data.module) end end - -- Adding data.forwardNodeId for nicer node:label() output. - for i,forwardNode in ipairs(self.forwardnodes) do - forwardNode.data.forwardNodeId = forwardNode.id - end self.output = nil self.gradInput = nil @@ -155,47 +152,6 @@ function gModule:map(gm, func) end end -function gModule:clone(...) - local f = torch.MemoryFile("rw"):binary() - f:writeObject(self) - f:seek(1) - local clone = f:readObject() - f:close() - if select('#', ...) > 0 then - clone:share(self, ...) - end - return clone -end - -function gModule:share(gm, ...) - local args = {...} - self:map(gm, - function(subnet1, subnet2) - subnet1:share(subnet2, unpack(args)) - end) - return self -end - -function gModule:training() - parent.training(self) - for _, m in ipairs(self.modules) do - m:training() - end -end - -function gModule:evaluate() - parent.evaluate(self) - for _, m in ipairs(self.modules) do - m:evaluate() - end -end - -function gModule:applyToModules(func) - for _, m in ipairs(self.modules) do - func(m) - end -end - --[[ Recursively applies type(type_str) to any tensors in the argument. If the argument is a tensor, type(type_str) is applied; if the argument is an array, this function recurses into it. 
]] @@ -239,12 +195,6 @@ function gModule:type(type, tensorCache) return self end -function gModule:zeroGradParameters() - for _, m in ipairs(self.modules) do - m:zeroGradParameters() - end -end - function gModule:updateOutput(input) return self:runForwardFunction('updateOutput',input) end @@ -443,24 +393,6 @@ function gModule:accGradParameters(input,gradOutput,lr) end end -function gModule:parameters() - local p,gp = {},{} - for _,node in ipairs(self.forwardnodes) do - if node.data.module then - local mp,mgp = node.data.module:parameters() - if mp and mgp then - for i = 1,#mp do - table.insert(p,mp[i]) - table.insert(gp,mgp[i]) - end - end - end - end - return p,gp -end - - function gModule:__tostring__() return self.name or torch.type(self) end - |