From 5448b95c9ee02f28adc0ef424071023ff17e8aa4 Mon Sep 17 00:00:00 2001 From: Malcolm Reynolds Date: Wed, 20 Jan 2016 14:34:47 +0000 Subject: Store a reverse mapping when wiring together graph, detect unused nodes. The connectivity checking code was previously unable to detect the following error case: local input = nn.Identity()() local usedOutput = nn.Linear(20, 10)(input) local unusedOutput = nn.Linear(20, 10)(input) local gmod = nn.gModule({input}, {usedOutput}) With this fix, when gModule is called it will throw an error, because of unusedOutput. This is a backwards incompatible change, but I feel that the current flexibility is error prone, and I can't see any advantage to it. We have flushed out a couple of bugs in internal code with this change. --- gmodule.lua | 23 +++++++++++++++++++++++ node.lua | 8 ++++++++ test/test_connectivity.lua | 26 ++++++++++++++++++++++++++ 3 files changed, 57 insertions(+) create mode 100644 test/test_connectivity.lua diff --git a/gmodule.lua b/gmodule.lua index 99698c0..910e2c4 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -130,6 +130,29 @@ function gModule:__init(inputs,outputs) debugLabel)) end + -- Check whether any nodes were defined as taking this node as an input, + -- but then left dangling and don't connect to the output. If this is + -- the case, then they won't be present in forwardnodes, so error out. + for successor, _ in pairs(node.data.reverseMap) do + local successorIsInGraph = false + + -- Only need to the part of forwardnodes from i onwards, topological + -- sort guarantees it cannot be in the first part. + for j = i+1, #self.forwardnodes do + -- Compare equality of data tables, as new Node objects have been + -- created by processes such as topoological sort, but the + -- underlying .data table is shared. + if self.forwardnodes[j].data == successor.data then + successorIsInGraph = true + break + end + end + local errStr = + "node declared on %s does not connect to gmodule output" + assert(successorIsInGraph, + string.format(errStr, successor.data.annotations._debugLabel)) + end + -- set data.forwardNodeId for node:label() output node.data.forwardNodeId = node.id diff --git a/node.lua b/node.lua index a55aa48..93f35bc 100644 --- a/node.lua +++ b/node.lua @@ -11,6 +11,7 @@ function nnNode:__init(data) parent.__init(self,data) self.data.annotations = self.data.annotations or {} self.data.mapindex = self.data.mapindex or {} + self.data.reverseMap = self.data.reverseMap or {} if not self.data.annotations._debugLabel then self:_makeDebugLabel(debug.getinfo(6, 'Sl')) end @@ -39,6 +40,13 @@ function nnNode:add(child,domap) assert(not mapindex[data], "Don't pass the same input twice.") table.insert(mapindex,data) mapindex[data] = #mapindex + + -- The "child" that is added here actually represents the input node, + -- so we write into that node to indicate that we are downstream of it. + -- This enables dangling pointer detection. + local revMap = child.data.reverseMap + assert(not revMap[self], 'this connection has already been made!') + revMap[self] = true end end diff --git a/test/test_connectivity.lua b/test/test_connectivity.lua new file mode 100644 index 0000000..99b2539 --- /dev/null +++ b/test/test_connectivity.lua @@ -0,0 +1,26 @@ +local totem = require 'totem' +require 'nngraph' +local tests = totem.TestSuite() +local tester = totem.Tester() + +function tests.connectivity() + -- Store debug info here, need to call debug.getinfo on same line as the + -- dangling pointer is declared. + local dInfo + local input = nn.Identity()() + local lin = nn.Linear(20, 10)(input) + -- The Sigmoid does not connect to the output, so should cause an error + -- when we call gModule. + local dangling = nn.Sigmoid()(lin); dInfo = debug.getinfo(1, 'Sl') + local actualOutput = nn.Tanh()(lin) + local errStr = string.format( + 'node declared on %%[%s%%]:%d_ does not connect to gmodule output', + dInfo.short_src, dInfo.currentline) + tester:assertErrorPattern( + function() + return nn.gModule({input}, {actualOutput}) + end, + errStr) +end + +return tester:add(tests):run() -- cgit v1.2.3