Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2015-09-11 03:59:13 +0300
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2015-09-11 03:59:13 +0300
commitfdf8c99c59959ebec7310c33ea185091f59bb818 (patch)
tree7a1e9d1883227d3419d38179eac37890a731988d
parent5f481a078c8cd4921af3f86a88204ca5f1d23f2d (diff)
support for parameter nodes
-rw-r--r--gmodule.lua27
1 files changed, 25 insertions, 2 deletions
diff --git a/gmodule.lua b/gmodule.lua
index 41ae384..3aba8a3 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -85,8 +85,31 @@ function gModule:__init(inputs,outputs)
-- the complete graph is constructed
-- now regenerate the graphs with the additional nodes
- assert(#self.fg:roots() == 1, "expecting only one start")
- self.innode = self.fg:roots()[1]
+
+ local roots = self.fg:roots()
+ -- if there are more than one root in the forward graph, then make sure that
+ -- extra roots are parameter nodes
+ if #roots > 1 then
+ local innodeRoot = nil
+ -- first find our innode
+ for _, root in ipairs(roots) do
+ if root.data == innode.data then
+ assert(innodeRoot == nil, 'more than one matching input node found in leaves')
+ innodeRoot = root
+ else
+ assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
+ assert(torch.typename(root.data.module) == 'nnop.Parameters',
+ 'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))
+ end
+ end
+ assert(innodeRoot ~= nil, 'input node not found among roots')
+ self.innode = innodeRoot
+ else
+ assert(#self.fg:roots() == 1, "expecting only one start")
+ self.innode = self.fg:roots()[1]
+ end
+
+
assert(self.innode.data == innode.data, "expecting the forward innode")
self.outnode = outnode
self.verbose = false