diff options
author | koray kavukcuoglu <koray@kavukcuoglu.org> | 2015-09-11 03:59:13 +0300 |
---|---|---|
committer | koray kavukcuoglu <koray@kavukcuoglu.org> | 2015-09-11 03:59:13 +0300 |
commit | fdf8c99c59959ebec7310c33ea185091f59bb818 (patch) | |
tree | 7a1e9d1883227d3419d38179eac37890a731988d | |
parent | 5f481a078c8cd4921af3f86a88204ca5f1d23f2d (diff) |
support for parameter nodes
-rw-r--r-- | gmodule.lua | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/gmodule.lua b/gmodule.lua index 41ae384..3aba8a3 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -85,8 +85,31 @@ function gModule:__init(inputs,outputs) -- the complete graph is constructed -- now regenerate the graphs with the additional nodes - assert(#self.fg:roots() == 1, "expecting only one start") - self.innode = self.fg:roots()[1] + + local roots = self.fg:roots() + -- if there are more than one root in the forward graph, then make sure that + -- extra roots are parameter nodes + if #roots > 1 then + local innodeRoot = nil + -- first find our innode + for _, root in ipairs(roots) do + if root.data == innode.data then + assert(innodeRoot == nil, 'more than one matching input node found in leaves') + innodeRoot = root + else + assert(root.data.module, 'Expected nnop.Parameters node, module not found in node') + assert(torch.typename(root.data.module) == 'nnop.Parameters', + 'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module)) + end + end + assert(innodeRoot ~= nil, 'input node not found among roots') + self.innode = innodeRoot + else + assert(#self.fg:roots() == 1, "expecting only one start") + self.innode = self.fg:roots()[1] + end + + assert(self.innode.data == innode.data, "expecting the forward innode") self.outnode = outnode self.verbose = false |