Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMalcolm Reynolds <mareynolds@google.com>2015-11-19 20:43:25 +0300
committerMalcolm Reynolds <mareynolds@google.com>2015-11-19 20:43:25 +0300
commitf8a3fd2b7b9e413f2e7e0524ff29f2f89e9e2419 (patch)
tree2c0b361afc064844204f8d232934772d1bfb8e5f
parent233e126b08414b9e4fad022da02c6f3af3225493 (diff)
Make error messages clearer, disallow empty table in inputs.
This is intended to address the common class of errors I see where people make a mistake connecting up their modules, but the error message is either unclear, or doesn't point towards where the mistake actually is. The 'what is this in the input' is now explicit about what the problem is, and if people pass in a nn.Module (meaning they probably forgot a set of parentheses) instead of a nngraph.Node, we say this explicitly. The '1 of split(2) outputs unused' (which previously provided no information about which split was incorrect) now includes file / line number of both the place where the Node was constructed, and the place where :split() was called. Hopefully this should reduce debugging time drastically. Finally, I have disallow passing an empty table as the input connections, ie 'nn.Identity()({})' will error. I cannot see a use case for this (if you have no input connections, just leave the second parens empty). The risk of this is when people do 'nn.Identity()({variableWithTypo})', thinking they have made a connection but actually they haven't. This is likely to cause errors much later on, whereas with this commit it errors straight away. This *could* break existing code, but theres an easy to apply fix that needs to be done at each callsite. Koray has approved this restriction to the API, but I appreciate others may have a view here..
-rw-r--r--gmodule.lua25
-rw-r--r--init.lua14
-rw-r--r--node.lua13
-rw-r--r--test/test_debug.lua80
-rw-r--r--utils.lua17
5 files changed, 133 insertions, 16 deletions
diff --git a/gmodule.lua b/gmodule.lua
index a360d2f..65cf047 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -47,18 +47,19 @@ function gModule:__init(inputs,outputs)
-- we will define a dummy output node that connects all output modules
-- into itself. This will be the output for the forward graph and
-- input point for the backward graph
+ local node
local outnode = nngraph.Node({input={}})
- for i,n in ipairs(outputs) do
- if torch.typename(n) ~= 'nngraph.Node' then
- error(string.format('what is this in the outputs[%s]? %s',
- i, tostring(n)))
+ for i = 1, table.maxn(outputs) do
+ node = outputs[i]
+ if torch.typename(node) ~= 'nngraph.Node' then
+ error(utils.expectingNodeErrorMessage(node, 'outputs', i))
end
- outnode:add(n,true)
+ outnode:add(node, true)
end
- for i,n in ipairs(inputs) do
- if torch.typename(n) ~= 'nngraph.Node' then
- error(string.format('what is this in the inputs[%s]? %s',
- i, tostring(n)))
+ for i = 1, table.maxn(inputs) do
+ node = inputs[i]
+ if torch.typename(node) ~= 'nngraph.Node' then
+ error(utils.expectingNodeErrorMessage(node, 'inputs', i))
end
end
-- We add also a dummy input node.
@@ -122,7 +123,11 @@ function gModule:__init(inputs,outputs)
-- check for unused inputs or unused split() outputs
if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then
local nUnused = node.data.nSplitOutputs - #node.children
- error(string.format("%s of split(%s) outputs are unused", nUnused, node.data.nSplitOutputs))
+ local debugLabel = node.data.annotations._debugLabel
+ local errStr =
+ "%s of split(%s) outputs from the node declared at %s are unused"
+ error(string.format(errStr, nUnused, node.data.nSplitOutputs,
+ debugLabel))
end
-- set data.forwardNodeId for node:label() output
diff --git a/init.lua b/init.lua
index 9b117a6..8b45341 100644
--- a/init.lua
+++ b/init.lua
@@ -27,16 +27,24 @@ function Module:__call__(...)
local input = ...
if nArgs == 1 and input == nil then
- error('what is this in the input? nil')
+ error(utils.expectingNodeErrorMessage(input, 'inputs', 1))
+ end
+ -- Disallow passing empty table, in case someone passes a table with some
+ -- typo'd variable name in.
+ if type(input) == 'table' and next(input) == nil then
+ error('cannot pass an empty table of inputs. To indicate no incoming ' ..
+ 'connections, leave the second set of parens blank.')
end
if not istable(input) then
input = {input}
end
local mnode = nngraph.Node({module=self})
- for i,dnode in ipairs(input) do
+ local dnode
+ for i = 1, table.maxn(input) do
+ dnode = input[i]
if torch.typename(dnode) ~= 'nngraph.Node' then
- error('what is this in the input? ' .. tostring(dnode))
+ error(utils.expectingNodeErrorMessage(dnode, 'inputs', i))
end
mnode:add(dnode,true)
end
diff --git a/node.lua b/node.lua
index 0f2261f..3842605 100644
--- a/node.lua
+++ b/node.lua
@@ -20,8 +20,10 @@ end
making a graph.]]
function nnNode:_makeDebugLabel(dinfo)
if dinfo then
- self.data.annotations._debugLabel = string.format('[%s]:%d',
- dinfo.short_src, dinfo.currentline, dinfo.name)
+ self.data.annotations._debugLabel = string.format('[%s]:%d_%s',
+ dinfo.short_src,
+ dinfo.currentline,
+ dinfo.name or '')
end
end
@@ -46,7 +48,12 @@ end
function nnNode:split(noutput)
assert(noutput >= 2, "splitting to one output is not supported")
local debugLabel = self.data.annotations._debugLabel
- local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}})
+ -- Specify the source location where :split is called.
+ local dinfo = debug.getinfo(2, 'Sl')
+ local splitLoc = string.format(' split at [%s]:%d',
+ dinfo.short_src,
+ dinfo.currentline)
+ local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. splitLoc .. '-mnode'}})
mnode:add(self,true)
local selectnodes = {}
diff --git a/test/test_debug.lua b/test/test_debug.lua
new file mode 100644
index 0000000..f5c8b06
--- /dev/null
+++ b/test/test_debug.lua
@@ -0,0 +1,80 @@
+local totem = require 'totem'
+require 'nngraph'
+local tests = totem.TestSuite()
+local tester = totem.Tester()
+
+function tests.whatIsThisInTheInput()
+ tester:assertErrorPattern(
+ function()
+ local inp1, inp2 = nn.Identity()(), nn.Identity() -- missing 2nd parens
+ local lin = nn.Linear(20, 10)(nn.CMulTable(){inp1, inp2})
+ end,
+ 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' ..
+ 'only valid thing to pass is an instance of nngraph%.Node')
+
+ tester:assertErrorPattern(
+ function()
+ -- pass-through module, again with same mistake
+ local graphNode, nnModule = nn.Identity()(), nn.Identity()
+ return nn.gModule({graphNode, nnModule}, {graphNode})
+ end,
+ 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' ..
+ 'only valid thing to pass is an instance of nngraph%.Node')
+
+ tester:assertErrorPattern(
+ function()
+ local input = nn.Identity()()
+ local out1 = nn.Linear(20, 10)(input)
+ local out2 = nn.Sigmoid()(input)
+ local unconnectedOut = nn.Linear(2, 3)
+ return nn.gModule({input}, {out1, out2, unconnectedOut})
+ end,
+ 'outputs%[3%] is an nn%.Module, specifically a nn%.Linear, but the ' ..
+ 'only valid thing to pass is an instance of nngraph%.Node')
+
+ -- Check for detecting a nil in the middle of a table.
+ tester:assertErrorPattern(
+ function()
+ local input = nn.Identity()()
+ local out1 = nn.Tanh()(input)
+ local out2 = nn.Sigmoid()(input)
+ -- nil here is simulating a mis-spelt variable name
+ return nn.gModule({input}, {out1, nil, out2})
+ end,
+ 'outputs%[2%] is nil %(typo / bad index%?%)')
+
+ tester:assertErrorPattern(
+ function()
+ -- Typo variable name returns nil, meaning an empty table
+ local input = nn.Identity()({aNonExistentVariable})
+ end,
+ 'cannot pass an empty table of inputs%. To indicate no incoming ' ..
+ 'connections, leave the second set of parens blank%.')
+end
+
+function tests.splitUnused()
+ -- Need to do debuginfo on the same lines as the other code here to match
+ -- what debug.getinfo inside those calls will return
+ local dInfoDeclare, dInfoSplit
+ local input = nn.Identity()(); dInfoDeclare = debug.getinfo(1, 'Sl')
+ local output, unused = input:split(2); dInfoSplit = debug.getinfo(1, 'Sl')
+
+ local function willCrash()
+ return nn.gModule({input}, {output})
+ end
+
+ -- Work out what strings will be in the error message
+ local declareLoc = string.format('%%[%s%%]:%d_',
+ dInfoDeclare.short_src,
+ dInfoDeclare.currentline)
+ local splitLoc = string.format('%%[%s%%]:%d',
+ dInfoSplit.short_src,
+ dInfoSplit.currentline)
+
+ tester:assertErrorPattern(
+ willCrash,
+ '1 of split%(2%) outputs from the node declared at ' ..
+ declareLoc .. ' split at ' .. splitLoc .. '%-mnode are unused')
+end
+
+tester:add(tests):run()
diff --git a/utils.lua b/utils.lua
index c0bccb2..e2686f0 100644
--- a/utils.lua
+++ b/utils.lua
@@ -8,4 +8,21 @@ function utils.istable(x)
return type(x) == 'table' and not torch.typename(x)
end
+--[[ Returns a useful error message when a nngraph.Node is expected. ]]
+function utils.expectingNodeErrorMessage(badVal, array, idx)
+ if badVal == nil then
+ return string.format('%s[%d] is nil (typo / bad index?)', array, idx)
+ elseif torch.isTypeOf(badVal, 'nn.Module') then
+ local errStr = '%s[%d] is an nn.Module, specifically a %s, but the ' ..
+ 'only valid thing to pass is an instance of ' ..
+ 'nngraph.Node. Did you forget a second set of parens, ' ..
+ 'which convert a nn.Module to a nngraph.Node?'
+ return string.format(errStr, array, idx, torch.typename(badVal))
+ else
+ local errStr = '%s[%d] should be an nngraph.Node but is of type %s'
+ return string.format(errStr, array, idx,
+ torch.typename(badVal) or type(badVal))
+ end
+end
+
return utils