diff options
author | Ivo Danihelka <ivo@danihelka.net> | 2014-03-03 21:59:58 +0400 |
---|---|---|
committer | Ivo Danihelka <ivo@danihelka.net> | 2014-03-03 21:59:58 +0400 |
commit | 370e0463089af2f876b29825620c14bfc50a7a14 (patch) | |
tree | 26ee1c7105d0874e123376a45c17e1e8a4cfbe59 /graphinspecting.lua | |
parent | 689810747744678586149ca9022e11d7d119a4cc (diff) |
Saved the graph .svg when nngraph.setDebug(true) is enabled.
Diffstat (limited to 'graphinspecting.lua')
-rw-r--r-- | graphinspecting.lua | 119 |
1 files changed, 119 insertions, 0 deletions
diff --git a/graphinspecting.lua b/graphinspecting.lua new file mode 100644 index 0000000..045be31 --- /dev/null +++ b/graphinspecting.lua @@ -0,0 +1,119 @@ + +-- The findCurrentNode() depends on the names of the +-- local variables in the nngraph.gModule source code. +local function findCurrentNode() + for level = 2, math.huge do + local info = debug.getinfo(level, "n") + if info == nil then + return nil + end + + local funcName = info.name + if funcName == "neteval" then + local varName, node = debug.getlocal(level, 1) + if varName == "node" then + return node + end + end + end +end + +-- Runs the func and calls onError(failedNode, ...) on an error. +-- The stack trace is inspected to find the failedNode. +local function runChecked(func, onError, ...) + -- The current node needs to be searched-for, before unrolling the stack. + local failedNode + local function errorHandler(message) + -- The stack traceback is added only if not already present. + if not string.find(message, 'stack traceback:\n', 1, true) then + message = debug.traceback(message, 2) + end + failedNode = findCurrentNode() + return message + end + + local ok, result = xpcall(func, errorHandler) + if ok then + return result + end + + onError(failedNode, ...) + -- Passing the level 0 avoids adding an additional error position info + -- to the message. + error(result, 0) +end + +local function customToDot(graph, title, failedNode) + local str = graph:todot(title) + if not failedNode then + return str + end + + local failedNodeId = nil + for i, node in ipairs(graph.nodes) do + if node.data == failedNode.data then + failedNodeId = node.id + break + end + end + + if failedNodeId ~= nil then + -- The closing '}' is removed. + -- And red fillcolor is specified for the failedNode. + str = string.gsub(str, '}%s*$', '') + str = str .. string.format('n%s[style=filled, fillcolor=red];\n}', + failedNodeId) + end + return str +end + +local function saveSvg(svgPathPrefix, dotStr) + io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix)) + local dotPath = svgPathPrefix .. '.dot' + local dotFile = io.open(dotPath, 'w') + dotFile:write(dotStr) + dotFile:close() + + local svgPath = svgPathPrefix .. '.svg' + local cmd = string.format('dot -Tsvg -o %s %s', svgPath, dotPath) + os.execute(cmd) +end + +local function onError(failedNode, gmodule) + local nInputs = gmodule.nInputs or #gmodule.innode.children + local svgPathPrefix = gmodule.name or string.format( + 'nngraph_%sin_%sout', nInputs, #gmodule.outnode.children) + local dotStr = customToDot(gmodule.fg, svgPathPrefix, failedNode) + saveSvg(svgPathPrefix, dotStr) +end + +local origFuncs = { + runForwardFunction = nn.gModule.runForwardFunction, + updateGradInput = nn.gModule.updateGradInput, + accGradParameters = nn.gModule.accGradParameters, +} + +-- When debug is enabled, +-- a gmodule.name .. '.svg' will be saved +-- if an exception occurs in a graph execution. +-- The problematic node will be marked by red color. +function nngraph.setDebug(enable) + if not enable then + -- When debug is disabled, + -- the origFuncs are restored on nn.gModule. + for funcName, origFunc in pairs(origFuncs) do + nn.gModule[funcName] = origFunc + end + return + end + + for funcName, origFunc in pairs(origFuncs) do + nn.gModule[funcName] = function(...) + local args = {...} + local gmodule = args[1] + return runChecked(function() + return origFunc(unpack(args)) + end, onError, gmodule) + end + end +end |