diff options
author | Malcolm Reynolds <mareynolds@google.com> | 2016-01-20 17:34:47 +0300 |
---|---|---|
committer | Malcolm Reynolds <mareynolds@google.com> | 2016-01-20 17:34:47 +0300 |
commit | 5448b95c9ee02f28adc0ef424071023ff17e8aa4 (patch) | |
tree | 3565ce467b898b0d1f92f56a4cd493e40e9a973b | |
parent | 6cd17c40f4cbc426d20894a58d81360724253333 (diff) |
Store a reverse mapping when wiring together graph, detect unused nodes.
The connectivity checking code was previously unable to detect the
following error case:
local input = nn.Identity()()
local usedOutput = nn.Linear(20, 10)(input)
local unusedOutput = nn.Linear(20, 10)(input)
local gmod = nn.gModule({input}, {usedOutput})
With this fix, when gModule is called it will throw an error, because of
unusedOutput.
This is a backwards incompatible change, but I feel that the current
flexibility is error prone, and I can't see any advantage to it. We have
flushed out a couple of bugs in internal code with this change.
-rw-r--r-- | gmodule.lua | 23 | ||||
-rw-r--r-- | node.lua | 8 | ||||
-rw-r--r-- | test/test_connectivity.lua | 26 |
3 files changed, 57 insertions, 0 deletions
diff --git a/gmodule.lua b/gmodule.lua index 99698c0..910e2c4 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -130,6 +130,29 @@ function gModule:__init(inputs,outputs) debugLabel)) end + -- Check whether any nodes were defined as taking this node as an input, + -- but then left dangling and don't connect to the output. If this is + -- the case, then they won't be present in forwardnodes, so error out. + for successor, _ in pairs(node.data.reverseMap) do + local successorIsInGraph = false + + -- Only need to the part of forwardnodes from i onwards, topological + -- sort guarantees it cannot be in the first part. + for j = i+1, #self.forwardnodes do + -- Compare equality of data tables, as new Node objects have been + -- created by processes such as topoological sort, but the + -- underlying .data table is shared. + if self.forwardnodes[j].data == successor.data then + successorIsInGraph = true + break + end + end + local errStr = + "node declared on %s does not connect to gmodule output" + assert(successorIsInGraph, + string.format(errStr, successor.data.annotations._debugLabel)) + end + -- set data.forwardNodeId for node:label() output node.data.forwardNodeId = node.id @@ -11,6 +11,7 @@ function nnNode:__init(data) parent.__init(self,data) self.data.annotations = self.data.annotations or {} self.data.mapindex = self.data.mapindex or {} + self.data.reverseMap = self.data.reverseMap or {} if not self.data.annotations._debugLabel then self:_makeDebugLabel(debug.getinfo(6, 'Sl')) end @@ -39,6 +40,13 @@ function nnNode:add(child,domap) assert(not mapindex[data], "Don't pass the same input twice.") table.insert(mapindex,data) mapindex[data] = #mapindex + + -- The "child" that is added here actually represents the input node, + -- so we write into that node to indicate that we are downstream of it. + -- This enables dangling pointer detection. + local revMap = child.data.reverseMap + assert(not revMap[self], 'this connection has already been made!') + revMap[self] = true end end diff --git a/test/test_connectivity.lua b/test/test_connectivity.lua new file mode 100644 index 0000000..99b2539 --- /dev/null +++ b/test/test_connectivity.lua @@ -0,0 +1,26 @@ +local totem = require 'totem' +require 'nngraph' +local tests = totem.TestSuite() +local tester = totem.Tester() + +function tests.connectivity() + -- Store debug info here, need to call debug.getinfo on same line as the + -- dangling pointer is declared. + local dInfo + local input = nn.Identity()() + local lin = nn.Linear(20, 10)(input) + -- The Sigmoid does not connect to the output, so should cause an error + -- when we call gModule. + local dangling = nn.Sigmoid()(lin); dInfo = debug.getinfo(1, 'Sl') + local actualOutput = nn.Tanh()(lin) + local errStr = string.format( + 'node declared on %%[%s%%]:%d_ does not connect to gmodule output', + dInfo.short_src, dInfo.currentline) + tester:assertErrorPattern( + function() + return nn.gModule({input}, {actualOutput}) + end, + errStr) +end + +return tester:add(tests):run() |