diff options
author | Malcolm Reynolds <mareynolds@google.com> | 2015-11-19 20:43:25 +0300 |
---|---|---|
committer | Malcolm Reynolds <mareynolds@google.com> | 2015-11-19 20:43:25 +0300 |
commit | f8a3fd2b7b9e413f2e7e0524ff29f2f89e9e2419 (patch) | |
tree | 2c0b361afc064844204f8d232934772d1bfb8e5f | |
parent | 233e126b08414b9e4fad022da02c6f3af3225493 (diff) |
Make error messages clearer, disallow empty table in inputs.
This is intended to address the common class of errors I see where
people make a mistake connecting up their modules, but the error message
is either unclear, or doesn't point towards where the mistake actually
is.
The 'what is this in the input' is now explicit about what the problem
is, and if people pass in a nn.Module (meaning they probably forgot a
set of parentheses) instead of a nngraph.Node, we say this explicitly.
The '1 of split(2) outputs unused' (which previously provided no
information about which split was incorrect) now includes file / line
number of both the place where the Node was constructed, and the place
where :split() was called. Hopefully this should reduce debugging time
drastically.
Finally, I have disallow passing an empty table as the input
connections, ie 'nn.Identity()({})' will error. I cannot see a use case
for this (if you have no input connections, just leave the second parens
empty). The risk of this is when people do
'nn.Identity()({variableWithTypo})', thinking they have made a
connection but actually they haven't. This is likely to cause errors
much later on, whereas with this commit it errors straight away.
This *could* break existing code, but theres an easy to apply fix that
needs to be done at each callsite. Koray has approved this restriction
to the API, but I appreciate others may have a view here..
-rw-r--r-- | gmodule.lua | 25 | ||||
-rw-r--r-- | init.lua | 14 | ||||
-rw-r--r-- | node.lua | 13 | ||||
-rw-r--r-- | test/test_debug.lua | 80 | ||||
-rw-r--r-- | utils.lua | 17 |
5 files changed, 133 insertions, 16 deletions
diff --git a/gmodule.lua b/gmodule.lua index a360d2f..65cf047 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -47,18 +47,19 @@ function gModule:__init(inputs,outputs) -- we will define a dummy output node that connects all output modules -- into itself. This will be the output for the forward graph and -- input point for the backward graph + local node local outnode = nngraph.Node({input={}}) - for i,n in ipairs(outputs) do - if torch.typename(n) ~= 'nngraph.Node' then - error(string.format('what is this in the outputs[%s]? %s', - i, tostring(n))) + for i = 1, table.maxn(outputs) do + node = outputs[i] + if torch.typename(node) ~= 'nngraph.Node' then + error(utils.expectingNodeErrorMessage(node, 'outputs', i)) end - outnode:add(n,true) + outnode:add(node, true) end - for i,n in ipairs(inputs) do - if torch.typename(n) ~= 'nngraph.Node' then - error(string.format('what is this in the inputs[%s]? %s', - i, tostring(n))) + for i = 1, table.maxn(inputs) do + node = inputs[i] + if torch.typename(node) ~= 'nngraph.Node' then + error(utils.expectingNodeErrorMessage(node, 'inputs', i)) end end -- We add also a dummy input node. @@ -122,7 +123,11 @@ function gModule:__init(inputs,outputs) -- check for unused inputs or unused split() outputs if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then local nUnused = node.data.nSplitOutputs - #node.children - error(string.format("%s of split(%s) outputs are unused", nUnused, node.data.nSplitOutputs)) + local debugLabel = node.data.annotations._debugLabel + local errStr = + "%s of split(%s) outputs from the node declared at %s are unused" + error(string.format(errStr, nUnused, node.data.nSplitOutputs, + debugLabel)) end -- set data.forwardNodeId for node:label() output @@ -27,16 +27,24 @@ function Module:__call__(...) local input = ... if nArgs == 1 and input == nil then - error('what is this in the input? nil') + error(utils.expectingNodeErrorMessage(input, 'inputs', 1)) + end + -- Disallow passing empty table, in case someone passes a table with some + -- typo'd variable name in. + if type(input) == 'table' and next(input) == nil then + error('cannot pass an empty table of inputs. To indicate no incoming ' .. + 'connections, leave the second set of parens blank.') end if not istable(input) then input = {input} end local mnode = nngraph.Node({module=self}) - for i,dnode in ipairs(input) do + local dnode + for i = 1, table.maxn(input) do + dnode = input[i] if torch.typename(dnode) ~= 'nngraph.Node' then - error('what is this in the input? ' .. tostring(dnode)) + error(utils.expectingNodeErrorMessage(dnode, 'inputs', i)) end mnode:add(dnode,true) end @@ -20,8 +20,10 @@ end making a graph.]] function nnNode:_makeDebugLabel(dinfo) if dinfo then - self.data.annotations._debugLabel = string.format('[%s]:%d', - dinfo.short_src, dinfo.currentline, dinfo.name) + self.data.annotations._debugLabel = string.format('[%s]:%d_%s', + dinfo.short_src, + dinfo.currentline, + dinfo.name or '') end end @@ -46,7 +48,12 @@ end function nnNode:split(noutput) assert(noutput >= 2, "splitting to one output is not supported") local debugLabel = self.data.annotations._debugLabel - local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}}) + -- Specify the source location where :split is called. + local dinfo = debug.getinfo(2, 'Sl') + local splitLoc = string.format(' split at [%s]:%d', + dinfo.short_src, + dinfo.currentline) + local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. splitLoc .. '-mnode'}}) mnode:add(self,true) local selectnodes = {} diff --git a/test/test_debug.lua b/test/test_debug.lua new file mode 100644 index 0000000..f5c8b06 --- /dev/null +++ b/test/test_debug.lua @@ -0,0 +1,80 @@ +local totem = require 'totem' +require 'nngraph' +local tests = totem.TestSuite() +local tester = totem.Tester() + +function tests.whatIsThisInTheInput() + tester:assertErrorPattern( + function() + local inp1, inp2 = nn.Identity()(), nn.Identity() -- missing 2nd parens + local lin = nn.Linear(20, 10)(nn.CMulTable(){inp1, inp2}) + end, + 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' .. + 'only valid thing to pass is an instance of nngraph%.Node') + + tester:assertErrorPattern( + function() + -- pass-through module, again with same mistake + local graphNode, nnModule = nn.Identity()(), nn.Identity() + return nn.gModule({graphNode, nnModule}, {graphNode}) + end, + 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' .. + 'only valid thing to pass is an instance of nngraph%.Node') + + tester:assertErrorPattern( + function() + local input = nn.Identity()() + local out1 = nn.Linear(20, 10)(input) + local out2 = nn.Sigmoid()(input) + local unconnectedOut = nn.Linear(2, 3) + return nn.gModule({input}, {out1, out2, unconnectedOut}) + end, + 'outputs%[3%] is an nn%.Module, specifically a nn%.Linear, but the ' .. + 'only valid thing to pass is an instance of nngraph%.Node') + + -- Check for detecting a nil in the middle of a table. + tester:assertErrorPattern( + function() + local input = nn.Identity()() + local out1 = nn.Tanh()(input) + local out2 = nn.Sigmoid()(input) + -- nil here is simulating a mis-spelt variable name + return nn.gModule({input}, {out1, nil, out2}) + end, + 'outputs%[2%] is nil %(typo / bad index%?%)') + + tester:assertErrorPattern( + function() + -- Typo variable name returns nil, meaning an empty table + local input = nn.Identity()({aNonExistentVariable}) + end, + 'cannot pass an empty table of inputs%. To indicate no incoming ' .. + 'connections, leave the second set of parens blank%.') +end + +function tests.splitUnused() + -- Need to do debuginfo on the same lines as the other code here to match + -- what debug.getinfo inside those calls will return + local dInfoDeclare, dInfoSplit + local input = nn.Identity()(); dInfoDeclare = debug.getinfo(1, 'Sl') + local output, unused = input:split(2); dInfoSplit = debug.getinfo(1, 'Sl') + + local function willCrash() + return nn.gModule({input}, {output}) + end + + -- Work out what strings will be in the error message + local declareLoc = string.format('%%[%s%%]:%d_', + dInfoDeclare.short_src, + dInfoDeclare.currentline) + local splitLoc = string.format('%%[%s%%]:%d', + dInfoSplit.short_src, + dInfoSplit.currentline) + + tester:assertErrorPattern( + willCrash, + '1 of split%(2%) outputs from the node declared at ' .. + declareLoc .. ' split at ' .. splitLoc .. '%-mnode are unused') +end + +tester:add(tests):run() @@ -8,4 +8,21 @@ function utils.istable(x) return type(x) == 'table' and not torch.typename(x) end +--[[ Returns a useful error message when a nngraph.Node is expected. ]] +function utils.expectingNodeErrorMessage(badVal, array, idx) + if badVal == nil then + return string.format('%s[%d] is nil (typo / bad index?)', array, idx) + elseif torch.isTypeOf(badVal, 'nn.Module') then + local errStr = '%s[%d] is an nn.Module, specifically a %s, but the ' .. + 'only valid thing to pass is an instance of ' .. + 'nngraph.Node. Did you forget a second set of parens, ' .. + 'which convert a nn.Module to a nngraph.Node?' + return string.format(errStr, array, idx, torch.typename(badVal)) + else + local errStr = '%s[%d] should be an nngraph.Node but is of type %s' + return string.format(errStr, array, idx, + torch.typename(badVal) or type(badVal)) + end +end + return utils |