Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2015-11-22 01:26:45 +0300
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2015-11-22 01:26:45 +0300
commit2d4351704eee6f97c0652a4680b88cfae64ebb0d (patch)
tree9829cd593df3d301b4616920aa9d67bc15a6051b
parent233e126b08414b9e4fad022da02c6f3af3225493 (diff)
parenta7a661df1ddecc620a2e6a53de384b5e8995896a (diff)
Merge pull request #93 from malcolmreynolds/improve_error_messages
Make error messages clearer, disallow empty table in inputs.
-rw-r--r--gmodule.lua25
-rw-r--r--init.lua14
-rw-r--r--node.lua13
-rw-r--r--test/test_debug.lua80
-rw-r--r--utils.lua31
5 files changed, 147 insertions, 16 deletions
diff --git a/gmodule.lua b/gmodule.lua
index a360d2f..99698c0 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -47,18 +47,19 @@ function gModule:__init(inputs,outputs)
-- we will define a dummy output node that connects all output modules
-- into itself. This will be the output for the forward graph and
-- input point for the backward graph
+ local node
local outnode = nngraph.Node({input={}})
- for i,n in ipairs(outputs) do
- if torch.typename(n) ~= 'nngraph.Node' then
- error(string.format('what is this in the outputs[%s]? %s',
- i, tostring(n)))
+ for i = 1, utils.tableMaxN(outputs) do
+ node = outputs[i]
+ if torch.typename(node) ~= 'nngraph.Node' then
+ error(utils.expectingNodeErrorMessage(node, 'outputs', i))
end
- outnode:add(n,true)
+ outnode:add(node, true)
end
- for i,n in ipairs(inputs) do
- if torch.typename(n) ~= 'nngraph.Node' then
- error(string.format('what is this in the inputs[%s]? %s',
- i, tostring(n)))
+ for i = 1, utils.tableMaxN(inputs) do
+ node = inputs[i]
+ if torch.typename(node) ~= 'nngraph.Node' then
+ error(utils.expectingNodeErrorMessage(node, 'inputs', i))
end
end
-- We add also a dummy input node.
@@ -122,7 +123,11 @@ function gModule:__init(inputs,outputs)
-- check for unused inputs or unused split() outputs
if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #node.children then
local nUnused = node.data.nSplitOutputs - #node.children
- error(string.format("%s of split(%s) outputs are unused", nUnused, node.data.nSplitOutputs))
+ local debugLabel = node.data.annotations._debugLabel
+ local errStr =
+ "%s of split(%s) outputs from the node declared at %s are unused"
+ error(string.format(errStr, nUnused, node.data.nSplitOutputs,
+ debugLabel))
end
-- set data.forwardNodeId for node:label() output
diff --git a/init.lua b/init.lua
index 9b117a6..1eae7cb 100644
--- a/init.lua
+++ b/init.lua
@@ -27,16 +27,24 @@ function Module:__call__(...)
local input = ...
if nArgs == 1 and input == nil then
- error('what is this in the input? nil')
+ error(utils.expectingNodeErrorMessage(input, 'inputs', 1))
+ end
+ -- Disallow passing empty table, in case someone passes a table with some
+ -- typo'd variable name in.
+ if type(input) == 'table' and next(input) == nil then
+ error('cannot pass an empty table of inputs. To indicate no incoming ' ..
+ 'connections, leave the second set of parens blank.')
end
if not istable(input) then
input = {input}
end
local mnode = nngraph.Node({module=self})
- for i,dnode in ipairs(input) do
+ local dnode
+ for i = 1, utils.tableMaxN(input) do
+ dnode = input[i]
if torch.typename(dnode) ~= 'nngraph.Node' then
- error('what is this in the input? ' .. tostring(dnode))
+ error(utils.expectingNodeErrorMessage(dnode, 'inputs', i))
end
mnode:add(dnode,true)
end
diff --git a/node.lua b/node.lua
index 0f2261f..3842605 100644
--- a/node.lua
+++ b/node.lua
@@ -20,8 +20,10 @@ end
making a graph.]]
function nnNode:_makeDebugLabel(dinfo)
if dinfo then
- self.data.annotations._debugLabel = string.format('[%s]:%d',
- dinfo.short_src, dinfo.currentline, dinfo.name)
+ self.data.annotations._debugLabel = string.format('[%s]:%d_%s',
+ dinfo.short_src,
+ dinfo.currentline,
+ dinfo.name or '')
end
end
@@ -46,7 +48,12 @@ end
function nnNode:split(noutput)
assert(noutput >= 2, "splitting to one output is not supported")
local debugLabel = self.data.annotations._debugLabel
- local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}})
+ -- Specify the source location where :split is called.
+ local dinfo = debug.getinfo(2, 'Sl')
+ local splitLoc = string.format(' split at [%s]:%d',
+ dinfo.short_src,
+ dinfo.currentline)
+ local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. splitLoc .. '-mnode'}})
mnode:add(self,true)
local selectnodes = {}
diff --git a/test/test_debug.lua b/test/test_debug.lua
new file mode 100644
index 0000000..f5c8b06
--- /dev/null
+++ b/test/test_debug.lua
@@ -0,0 +1,80 @@
+local totem = require 'totem'
+require 'nngraph'
+local tests = totem.TestSuite()
+local tester = totem.Tester()
+
+function tests.whatIsThisInTheInput()
+ tester:assertErrorPattern(
+ function()
+ local inp1, inp2 = nn.Identity()(), nn.Identity() -- missing 2nd parens
+ local lin = nn.Linear(20, 10)(nn.CMulTable(){inp1, inp2})
+ end,
+ 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' ..
+ 'only valid thing to pass is an instance of nngraph%.Node')
+
+ tester:assertErrorPattern(
+ function()
+ -- pass-through module, again with same mistake
+ local graphNode, nnModule = nn.Identity()(), nn.Identity()
+ return nn.gModule({graphNode, nnModule}, {graphNode})
+ end,
+ 'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' ..
+ 'only valid thing to pass is an instance of nngraph%.Node')
+
+ tester:assertErrorPattern(
+ function()
+ local input = nn.Identity()()
+ local out1 = nn.Linear(20, 10)(input)
+ local out2 = nn.Sigmoid()(input)
+ local unconnectedOut = nn.Linear(2, 3)
+ return nn.gModule({input}, {out1, out2, unconnectedOut})
+ end,
+ 'outputs%[3%] is an nn%.Module, specifically a nn%.Linear, but the ' ..
+ 'only valid thing to pass is an instance of nngraph%.Node')
+
+ -- Check for detecting a nil in the middle of a table.
+ tester:assertErrorPattern(
+ function()
+ local input = nn.Identity()()
+ local out1 = nn.Tanh()(input)
+ local out2 = nn.Sigmoid()(input)
+ -- nil here is simulating a mis-spelt variable name
+ return nn.gModule({input}, {out1, nil, out2})
+ end,
+ 'outputs%[2%] is nil %(typo / bad index%?%)')
+
+ tester:assertErrorPattern(
+ function()
+ -- Typo variable name returns nil, meaning an empty table
+ local input = nn.Identity()({aNonExistentVariable})
+ end,
+ 'cannot pass an empty table of inputs%. To indicate no incoming ' ..
+ 'connections, leave the second set of parens blank%.')
+end
+
+function tests.splitUnused()
+ -- Need to do debuginfo on the same lines as the other code here to match
+ -- what debug.getinfo inside those calls will return
+ local dInfoDeclare, dInfoSplit
+ local input = nn.Identity()(); dInfoDeclare = debug.getinfo(1, 'Sl')
+ local output, unused = input:split(2); dInfoSplit = debug.getinfo(1, 'Sl')
+
+ local function willCrash()
+ return nn.gModule({input}, {output})
+ end
+
+ -- Work out what strings will be in the error message
+ local declareLoc = string.format('%%[%s%%]:%d_',
+ dInfoDeclare.short_src,
+ dInfoDeclare.currentline)
+ local splitLoc = string.format('%%[%s%%]:%d',
+ dInfoSplit.short_src,
+ dInfoSplit.currentline)
+
+ tester:assertErrorPattern(
+ willCrash,
+ '1 of split%(2%) outputs from the node declared at ' ..
+ declareLoc .. ' split at ' .. splitLoc .. '%-mnode are unused')
+end
+
+tester:add(tests):run()
diff --git a/utils.lua b/utils.lua
index c0bccb2..1b39607 100644
--- a/utils.lua
+++ b/utils.lua
@@ -8,4 +8,35 @@ function utils.istable(x)
return type(x) == 'table' and not torch.typename(x)
end
+--[[ Returns a useful error message when a nngraph.Node is expected. ]]
+function utils.expectingNodeErrorMessage(badVal, array, idx)
+ if badVal == nil then
+ return string.format('%s[%d] is nil (typo / bad index?)', array, idx)
+ elseif torch.isTypeOf(badVal, 'nn.Module') then
+ local errStr = '%s[%d] is an nn.Module, specifically a %s, but the ' ..
+ 'only valid thing to pass is an instance of ' ..
+ 'nngraph.Node. Did you forget a second set of parens, ' ..
+ 'which convert a nn.Module to a nngraph.Node?'
+ return string.format(errStr, array, idx, torch.typename(badVal))
+ else
+ local errStr = '%s[%d] should be an nngraph.Node but is of type %s'
+ return string.format(errStr, array, idx,
+ torch.typename(badVal) or type(badVal))
+ end
+end
+
+--[[ Lua 5.2+ removed table.maxn, provide fallback implementation. ]]
+if table.maxn then
+ utils.tableMaxN = table.maxn
+else
+ function utils.tableMaxN(tbl)
+ local max = 0
+ for k, v in pairs(tbl) do
+ if type(k) == 'number' and k > max then
+ max = k
+ end
+ end
+ return max
+ end
+ end
return utils