diff options
author | koray kavukcuoglu <koray@kavukcuoglu.org> | 2015-11-22 01:26:45 +0300 |
---|---|---|
committer | koray kavukcuoglu <koray@kavukcuoglu.org> | 2015-11-22 01:26:45 +0300 |
commit | 2d4351704eee6f97c0652a4680b88cfae64ebb0d (patch) | |
tree | 9829cd593df3d301b4616920aa9d67bc15a6051b | |
parent | 233e126b08414b9e4fad022da02c6f3af3225493 (diff) | |
parent | a7a661df1ddecc620a2e6a53de384b5e8995896a (diff) |
Merge pull request #93 from malcolmreynolds/improve_error_messages
Make error messages clearer, disallow empty table in inputs.
-rw-r--r-- | gmodule.lua | 25 | ||||
-rw-r--r-- | init.lua | 14 | ||||
-rw-r--r-- | node.lua | 13 | ||||
-rw-r--r-- | test/test_debug.lua | 80 | ||||
-rw-r--r-- | utils.lua | 31 |
5 files changed, 147 insertions, 16 deletions
diff --git a/gmodule.lua b/gmodule.lua index a360d2f..99698c0 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -47,18 +47,19 @@ function gModule:__init(inputs,outputs) -- we will define a dummy output node that connects all output modules -- into itself. This will be the output for the forward graph and -- input point for the backward graph + local node local outnode = nngraph.Node({input={}}) - for i,n in ipairs(outputs) do - if torch.typename(n) ~= 'nngraph.Node' then - error(string.format('what is this in the outputs[%s]? %s', - i, tostring(n))) + for i = 1, utils.tableMaxN(outputs) do + node = outputs[i] + if torch.typename(node) ~= 'nngraph.Node' then + error(utils.expectingNodeErrorMessage(node, 'outputs', i)) end - outnode:add(n,true) + outnode:add(node, true) end - for i,n in ipairs(inputs) do - if torch.typename(n) ~= 'nngraph.Node' then - error(string.format('what is this in the inputs[%s]? %s', - i, tostring(n))) + for i = 1, utils.tableMaxN(inputs) do + node = inputs[i] + if torch.typename(node) ~= 'nngraph.Node' then + error(utils.expectingNodeErrorMessage(node, 'inputs', i)) end end -- We add also a dummy input node. @@ -122,7 +123,11 @@ function gModule:__init(inputs,outputs) -- check for unused inputs or unused split() outputs if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then local nUnused = node.data.nSplitOutputs - #node.children - error(string.format("%s of split(%s) outputs are unused", nUnused, node.data.nSplitOutputs)) + local debugLabel = node.data.annotations._debugLabel + local errStr = + "%s of split(%s) outputs from the node declared at %s are unused" + error(string.format(errStr, nUnused, node.data.nSplitOutputs, + debugLabel)) end -- set data.forwardNodeId for node:label() output @@ -27,16 +27,24 @@ function Module:__call__(...) local input = ... if nArgs == 1 and input == nil then - error('what is this in the input? nil') + error(utils.expectingNodeErrorMessage(input, 'inputs', 1)) + end + -- Disallow passing empty table, in case someone passes a table with some + -- typo'd variable name in. + if type(input) == 'table' and next(input) == nil then + error('cannot pass an empty table of inputs. To indicate no incoming ' .. + 'connections, leave the second set of parens blank.') end if not istable(input) then input = {input} end local mnode = nngraph.Node({module=self}) - for i,dnode in ipairs(input) do + local dnode + for i = 1, utils.tableMaxN(input) do + dnode = input[i] if torch.typename(dnode) ~= 'nngraph.Node' then - error('what is this in the input? ' .. tostring(dnode)) + error(utils.expectingNodeErrorMessage(dnode, 'inputs', i)) end mnode:add(dnode,true) end @@ -20,8 +20,10 @@ end making a graph.]] function nnNode:_makeDebugLabel(dinfo) if dinfo then - self.data.annotations._debugLabel = string.format('[%s]:%d', - dinfo.short_src, dinfo.currentline, dinfo.name) + self.data.annotations._debugLabel = string.format('[%s]:%d_%s', + dinfo.short_src, + dinfo.currentline, + dinfo.name or '') end end @@ -46,7 +48,12 @@ end function nnNode:split(noutput) assert(noutput >= 2, "splitting to one output is not supported") local debugLabel = self.data.annotations._debugLabel - local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}}) + -- Specify the source location where :split is called. + local dinfo = debug.getinfo(2, 'Sl') + local splitLoc = string.format(' split at [%s]:%d', + dinfo.short_src, + dinfo.currentline) + local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. splitLoc .. '-mnode'}}) mnode:add(self,true) local selectnodes = {} diff --git a/test/test_debug.lua b/test/test_debug.lua new file mode 100644 index 0000000..f5c8b06 --- /dev/null +++ b/test/test_debug.lua @@ -0,0 +1,80 @@ +local totem = require 'totem' +require 'nngraph' +local tests = totem.TestSuite() +local tester = totem.Tester() + +function tests.whatIsThisInTheInput() + tester:assertErrorPattern( + function() + local inp1, inp2 = nn.Identity()(), nn.Identity() -- missing 2nd parens + local lin = nn.Linear(20, 10)(nn.CMulTable(){inp1, inp2}) + end, + 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' .. + 'only valid thing to pass is an instance of nngraph%.Node') + + tester:assertErrorPattern( + function() + -- pass-through module, again with same mistake + local graphNode, nnModule = nn.Identity()(), nn.Identity() + return nn.gModule({graphNode, nnModule}, {graphNode}) + end, + 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' .. + 'only valid thing to pass is an instance of nngraph%.Node') + + tester:assertErrorPattern( + function() + local input = nn.Identity()() + local out1 = nn.Linear(20, 10)(input) + local out2 = nn.Sigmoid()(input) + local unconnectedOut = nn.Linear(2, 3) + return nn.gModule({input}, {out1, out2, unconnectedOut}) + end, + 'outputs%[3%] is an nn%.Module, specifically a nn%.Linear, but the ' .. + 'only valid thing to pass is an instance of nngraph%.Node') + + -- Check for detecting a nil in the middle of a table. + tester:assertErrorPattern( + function() + local input = nn.Identity()() + local out1 = nn.Tanh()(input) + local out2 = nn.Sigmoid()(input) + -- nil here is simulating a mis-spelt variable name + return nn.gModule({input}, {out1, nil, out2}) + end, + 'outputs%[2%] is nil %(typo / bad index%?%)') + + tester:assertErrorPattern( + function() + -- Typo variable name returns nil, meaning an empty table + local input = nn.Identity()({aNonExistentVariable}) + end, + 'cannot pass an empty table of inputs%. To indicate no incoming ' .. + 'connections, leave the second set of parens blank%.') +end + +function tests.splitUnused() + -- Need to do debuginfo on the same lines as the other code here to match + -- what debug.getinfo inside those calls will return + local dInfoDeclare, dInfoSplit + local input = nn.Identity()(); dInfoDeclare = debug.getinfo(1, 'Sl') + local output, unused = input:split(2); dInfoSplit = debug.getinfo(1, 'Sl') + + local function willCrash() + return nn.gModule({input}, {output}) + end + + -- Work out what strings will be in the error message + local declareLoc = string.format('%%[%s%%]:%d_', + dInfoDeclare.short_src, + dInfoDeclare.currentline) + local splitLoc = string.format('%%[%s%%]:%d', + dInfoSplit.short_src, + dInfoSplit.currentline) + + tester:assertErrorPattern( + willCrash, + '1 of split%(2%) outputs from the node declared at ' .. + declareLoc .. ' split at ' .. splitLoc .. '%-mnode are unused') +end + +tester:add(tests):run() @@ -8,4 +8,35 @@ function utils.istable(x) return type(x) == 'table' and not torch.typename(x) end +--[[ Returns a useful error message when a nngraph.Node is expected. ]] +function utils.expectingNodeErrorMessage(badVal, array, idx) + if badVal == nil then + return string.format('%s[%d] is nil (typo / bad index?)', array, idx) + elseif torch.isTypeOf(badVal, 'nn.Module') then + local errStr = '%s[%d] is an nn.Module, specifically a %s, but the ' .. + 'only valid thing to pass is an instance of ' .. + 'nngraph.Node. Did you forget a second set of parens, ' .. + 'which convert a nn.Module to a nngraph.Node?' + return string.format(errStr, array, idx, torch.typename(badVal)) + else + local errStr = '%s[%d] should be an nngraph.Node but is of type %s' + return string.format(errStr, array, idx, + torch.typename(badVal) or type(badVal)) + end +end + +--[[ Lua 5.2+ removed table.maxn, provide fallback implementation. ]] +if table.maxn then + utils.tableMaxN = table.maxn +else + function utils.tableMaxN(tbl) + local max = 0 + for k, v in pairs(tbl) do + if type(k) == 'number' and k > max then + max = k + end + end + return max + end + end return utils |