Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMalcolm Reynolds <mareynolds@google.com>2016-01-20 17:34:47 +0300
committerMalcolm Reynolds <mareynolds@google.com>2016-01-20 17:34:47 +0300
commit5448b95c9ee02f28adc0ef424071023ff17e8aa4 (patch)
tree3565ce467b898b0d1f92f56a4cd493e40e9a973b
parent6cd17c40f4cbc426d20894a58d81360724253333 (diff)
Store a reverse mapping when wiring together graph, detect unused nodes.
The connectivity checking code was previously unable to detect the following error case: local input = nn.Identity()() local usedOutput = nn.Linear(20, 10)(input) local unusedOutput = nn.Linear(20, 10)(input) local gmod = nn.gModule({input}, {usedOutput}) With this fix, when gModule is called it will throw an error, because of unusedOutput. This is a backwards incompatible change, but I feel that the current flexibility is error prone, and I can't see any advantage to it. We have flushed out a couple of bugs in internal code with this change.
-rw-r--r--gmodule.lua23
-rw-r--r--node.lua8
-rw-r--r--test/test_connectivity.lua26
3 files changed, 57 insertions, 0 deletions
diff --git a/gmodule.lua b/gmodule.lua
index 99698c0..910e2c4 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -130,6 +130,29 @@ function gModule:__init(inputs,outputs)
debugLabel))
end
+ -- Check whether any nodes were defined as taking this node as an input,
+ -- but then left dangling and don't connect to the output. If this is
+ -- the case, then they won't be present in forwardnodes, so error out.
+ for successor, _ in pairs(node.data.reverseMap) do
+ local successorIsInGraph = false
+
+ -- Only need to the part of forwardnodes from i onwards, topological
+ -- sort guarantees it cannot be in the first part.
+ for j = i+1, #self.forwardnodes do
+ -- Compare equality of data tables, as new Node objects have been
+ -- created by processes such as topoological sort, but the
+ -- underlying .data table is shared.
+ if self.forwardnodes[j].data == successor.data then
+ successorIsInGraph = true
+ break
+ end
+ end
+ local errStr =
+ "node declared on %s does not connect to gmodule output"
+ assert(successorIsInGraph,
+ string.format(errStr, successor.data.annotations._debugLabel))
+ end
+
-- set data.forwardNodeId for node:label() output
node.data.forwardNodeId = node.id
diff --git a/node.lua b/node.lua
index a55aa48..93f35bc 100644
--- a/node.lua
+++ b/node.lua
@@ -11,6 +11,7 @@ function nnNode:__init(data)
parent.__init(self,data)
self.data.annotations = self.data.annotations or {}
self.data.mapindex = self.data.mapindex or {}
+ self.data.reverseMap = self.data.reverseMap or {}
if not self.data.annotations._debugLabel then
self:_makeDebugLabel(debug.getinfo(6, 'Sl'))
end
@@ -39,6 +40,13 @@ function nnNode:add(child,domap)
assert(not mapindex[data], "Don't pass the same input twice.")
table.insert(mapindex,data)
mapindex[data] = #mapindex
+
+ -- The "child" that is added here actually represents the input node,
+ -- so we write into that node to indicate that we are downstream of it.
+ -- This enables dangling pointer detection.
+ local revMap = child.data.reverseMap
+ assert(not revMap[self], 'this connection has already been made!')
+ revMap[self] = true
end
end
diff --git a/test/test_connectivity.lua b/test/test_connectivity.lua
new file mode 100644
index 0000000..99b2539
--- /dev/null
+++ b/test/test_connectivity.lua
@@ -0,0 +1,26 @@
+local totem = require 'totem'
+require 'nngraph'
+local tests = totem.TestSuite()
+local tester = totem.Tester()
+
+function tests.connectivity()
+ -- Store debug info here, need to call debug.getinfo on same line as the
+ -- dangling pointer is declared.
+ local dInfo
+ local input = nn.Identity()()
+ local lin = nn.Linear(20, 10)(input)
+ -- The Sigmoid does not connect to the output, so should cause an error
+ -- when we call gModule.
+ local dangling = nn.Sigmoid()(lin); dInfo = debug.getinfo(1, 'Sl')
+ local actualOutput = nn.Tanh()(lin)
+ local errStr = string.format(
+ 'node declared on %%[%s%%]:%d_ does not connect to gmodule output',
+ dInfo.short_src, dInfo.currentline)
+ tester:assertErrorPattern(
+ function()
+ return nn.gModule({input}, {actualOutput})
+ end,
+ errStr)
+end
+
+return tester:add(tests):run()