
github.com/torch/nngraph.git
author     Andreas Köpf <andreas.koepf@xamla.com>    2015-09-15 20:31:09 +0300
committer  Andreas Köpf <andreas.koepf@xamla.com>    2015-10-16 14:20:57 +0300
commit     377afbc9be15ae98135bca2924bc65b72710348f (patch)
tree       43686155ee73dcd9a88724ee842b6626a246c64c /gmodule.lua
parent     7a35aecd72082dba6f7f8a1b3ec309840360d81a (diff)
Use nn.Container as base class for gModule
Added the modules to the container in the constructor and removed the redundant methods training(), evaluate(), share(), zeroGradParameters(), parameters() and clone(), which are now provided by the base classes (nn.gModule -> nn.Container -> nn.Module).
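As an illustrative aside (not part of the commit), the sketch below shows what the new base class buys in practice: a gModule built with nngraph now answers parameters(), training(), evaluate() and zeroGradParameters() through the inherited nn.Container / nn.Module code paths rather than the gModule-specific overrides removed here. The layer sizes and graph are made up for the example.

-- Editor's sketch, not from the commit; layer sizes are hypothetical.
require 'nngraph'

-- a tiny graph: input -> Linear -> Tanh -> Linear
local input  = nn.Identity()()
local hidden = nn.Tanh()(nn.Linear(10, 20)(input))
local output = nn.Linear(20, 2)(hidden)
local net    = nn.gModule({input}, {output})

-- all of these now come from the base classes rather than gmodule.lua:
local params, gradParams = net:parameters()
print(#params)            -- 4: weight and bias of each Linear layer
net:training()            -- propagated to the child modules by nn.Container
net:evaluate()
net:zeroGradParameters()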
Diffstat (limited to 'gmodule.lua')
-rw-r--r--  gmodule.lua  98
1 file changed, 15 insertions, 83 deletions
diff --git a/gmodule.lua b/gmodule.lua
index d6e17d7..a5c7d91 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -1,4 +1,3 @@
-
local nesting = paths.dofile('nesting.lua')
local utils = paths.dofile('utils.lua')
local istensor = torch.isTensor
@@ -40,7 +39,7 @@ end
--
-- The node.data.gradOutput holds the to-be-summed gradOutputs.
-- Each node has only one output. So we need only one gradOutput.
-local gModule, parent = torch.class('nn.gModule','nn.Module')
+local gModule, parent = torch.class('nn.gModule','nn.Container')
function gModule:__init(inputs,outputs)
parent.__init(self)
@@ -109,7 +108,6 @@ function gModule:__init(inputs,outputs)
self.innode = self.fg:roots()[1]
end
-
assert(self.innode.data == innode.data, "expecting the forward innode")
self.outnode = outnode
self.verbose = false
@@ -118,24 +116,23 @@ function gModule:__init(inputs,outputs)
-- computation on the graph is done through topsort of forward and backward graphs
self.forwardnodes = self.fg:topsort()
self.backwardnodes = self.bg:topsort()
- self.modules = {}
- for _, node in ipairs(self.forwardnodes) do
- if node.data.module then
- table.insert(self.modules, node.data.module)
+
+ -- iterate over all nodes: check, tag and add to container
+ for i,node in ipairs(self.forwardnodes) do
+ -- check for unused inputs or unused split() outputs
+ if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then
+ local nUnused = node.data.nSplitOutputs - #node.children
+ error(string.format("%s of split(%s) outputs are unused", nUnused, node.data.nSplitOutputs))
end
- end
- -- Checking for unused inputs or unused split() outputs.
- for i,forwardNode in ipairs(self.forwardnodes) do
- if forwardNode.data.nSplitOutputs and forwardNode.data.nSplitOutputs ~= #forwardNode.children then
- local nUnused = forwardNode.data.nSplitOutputs - #forwardNode.children
- error(string.format("%s of split(%s) outputs are unused", nUnused,
- forwardNode.data.nSplitOutputs))
+
+ -- set data.forwardNodeId for node:label() output
+ node.data.forwardNodeId = node.id
+
+ -- add module to container
+ if node.data.module then
+ self:add(node.data.module)
end
end
- -- Adding data.forwardNodeId for nicer node:label() output.
- for i,forwardNode in ipairs(self.forwardnodes) do
- forwardNode.data.forwardNodeId = forwardNode.id
- end
self.output = nil
self.gradInput = nil
@@ -155,47 +152,6 @@ function gModule:map(gm, func)
end
end
-function gModule:clone(...)
- local f = torch.MemoryFile("rw"):binary()
- f:writeObject(self)
- f:seek(1)
- local clone = f:readObject()
- f:close()
- if select('#', ...) > 0 then
- clone:share(self, ...)
- end
- return clone
-end
-
-function gModule:share(gm, ...)
- local args = {...}
- self:map(gm,
- function(subnet1, subnet2)
- subnet1:share(subnet2, unpack(args))
- end)
- return self
-end
-
-function gModule:training()
- parent.training(self)
- for _, m in ipairs(self.modules) do
- m:training()
- end
-end
-
-function gModule:evaluate()
- parent.evaluate(self)
- for _, m in ipairs(self.modules) do
- m:evaluate()
- end
-end
-
-function gModule:applyToModules(func)
- for _, m in ipairs(self.modules) do
- func(m)
- end
-end
-
--[[ Recursively applies type(type_str) to any tensors in the argument. If the
argument is a tensor, type(type_str) is applied; if the argument is an array,
this function recurses into it. ]]
@@ -239,12 +195,6 @@ function gModule:type(type, tensorCache)
return self
end
-function gModule:zeroGradParameters()
- for _, m in ipairs(self.modules) do
- m:zeroGradParameters()
- end
-end
-
function gModule:updateOutput(input)
return self:runForwardFunction('updateOutput',input)
end
@@ -443,24 +393,6 @@ function gModule:accGradParameters(input,gradOutput,lr)
end
end
-function gModule:parameters()
- local p,gp = {},{}
- for _,node in ipairs(self.forwardnodes) do
- if node.data.module then
- local mp,mgp = node.data.module:parameters()
- if mp and mgp then
- for i = 1,#mp do
- table.insert(p,mp[i])
- table.insert(gp,mgp[i])
- end
- end
- end
- end
- return p,gp
-end
-
-
function gModule:__tostring__()
return self.name or torch.type(self)
end
-
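As a final illustrative note (again not part of the commit), the check kept in the rewritten constructor loop above still rejects graphs whose split() outputs are not all consumed. A hedged sketch of how that error surfaces, with made-up layer sizes:

-- Editor's sketch, not from the commit; sizes are hypothetical.
require 'nngraph'

local input = nn.Identity()()
local concat = nn.ConcatTable()
concat:add(nn.Linear(4, 4))
concat:add(nn.Linear(4, 4))
local branches = concat(input)   -- node producing a table of two tensors
local a, b = branches:split(2)   -- declare two split outputs...
local out = nn.Tanh()(a)         -- ...but only use the first

-- building the gModule should fail with "1 of split(2) outputs are unused"
local ok, err = pcall(function() return nn.gModule({input}, {out}) end)
print(ok, err)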