Welcome to mirror list, hosted at ThFree Co, Russian Federation.

graphinspecting.lua - github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 8e858cfc57f92cec777c58a5ef3e68cabe595a48 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159

-- Walks up the call stack, starting at the caller, looking for the
-- innermost active frame of a function named "neteval" whose first local
-- variable is called "node", and returns that variable's value.
-- Returns nil when the top of the stack is reached without a match.
-- This relies on the local-variable names used inside the
-- nngraph.gModule source code.
local function findCurrentNode()
   -- Level 1 is findCurrentNode itself, so the search begins one frame up.
   local level = 2
   local frameInfo = debug.getinfo(level, "n")
   while frameInfo ~= nil do
      if frameInfo.name == "neteval" then
         local localName, localValue = debug.getlocal(level, 1)
         if localName == "node" then
            return localValue
         end
      end
      level = level + 1
      frameInfo = debug.getinfo(level, "n")
   end
   return nil
end

-- Runs func() in protected mode and returns all of its results.
-- On an error, calls onError(failedNode, ...) and then re-raises the
-- original error object. failedNode is located by inspecting the call
-- stack from inside the error handler (see findCurrentNode); it is nil
-- when no node could be found.
local function runChecked(func, onError, ...)
   local unpack = unpack or table.unpack
   -- The explicit count preserves nil values among the extra arguments.
   local nExtra = select('#', ...)
   local extraArgs = {...}

   -- The current node needs to be searched-for, before unrolling the stack,
   -- i.e. from within the xpcall message handler.
   local failedNode
   local function errorHandler(message)
      -- Search first, so nothing below can prevent the node from being
      -- recorded.
      failedNode = findCurrentNode()
      -- The stack traceback is added only to string messages that do not
      -- already carry one. Other error objects (e.g. tables) are passed
      -- through untouched; calling string.find on them would raise inside
      -- the handler and replace the real error.
      if type(message) == 'string'
            and not string.find(message, 'stack traceback:\n', 1, true) then
         message = debug.traceback(message, 2)
      end
      return message
   end

   local function finish(ok, ...)
      if ok then
         return ...
      end
      local message = ...
      -- A failure while reporting (e.g. an unwritable file) must not mask
      -- the original error, so it is only logged.
      local reported, reportError = pcall(onError, failedNode,
                                          unpack(extraArgs, 1, nExtra))
      if not reported then
         io.stderr:write(string.format(
            'nngraph: failed to report the error: %s\n', tostring(reportError)))
      end
      -- Passing the level 0 avoids adding an additional error position info
      -- to the message.
      error(message, 0)
   end

   return finish(xpcall(func, errorHandler))
end

-- Renders `graph` to a DOT string titled `title` via graph:todot().
-- If `failedNode` is given and some node in graph.nodes shares its data,
-- that node is additionally declared with a red fill; otherwise the plain
-- DOT text is returned.
local function customToDot(graph, title, failedNode)
   local dot = graph:todot(title)

   local highlightId
   if failedNode then
      for _, candidate in ipairs(graph.nodes) do
         if candidate.data == failedNode.data then
            highlightId = candidate.id
            break
         end
      end
   end

   if highlightId == nil then
      return dot
   end

   -- Drop the trailing '}' so the styling line can be appended inside the
   -- graph body, then close the graph again.
   local openBody = (string.gsub(dot, '}%s*$', ''))
   return openBody .. string.format('n%s[style=filled, fillcolor=red];\n}',
                                    highlightId)
end

-- Quotes `s` as a single shell word so that paths containing spaces or
-- shell metacharacters reach the command intact.
local function shellQuote(s)
   if package.config:sub(1, 1) == '\\' then
      -- Windows cmd.exe: wrap in double quotes.
      return '"' .. s .. '"'
   end
   -- POSIX sh: wrap in single quotes; each embedded ' becomes '\''.
   return "'" .. (string.gsub(s, "'", "'\\''")) .. "'"
end

-- Writes dotStr to <svgPathPrefix>.dot and renders it with the Graphviz
-- `dot` command into <svgPathPrefix>.svg.
-- Raises an error if the .dot file cannot be opened; a failing `dot`
-- command only produces a warning on stderr.
local function saveSvg(svgPathPrefix, dotStr)
   io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix))
   local dotPath = svgPathPrefix .. '.dot'
   local dotFile = assert(io.open(dotPath, 'w'))
   dotFile:write(dotStr)
   dotFile:close()

   local svgPath = svgPathPrefix .. '.svg'
   local cmd = string.format('dot -Tsvg -o %s %s',
                             shellQuote(svgPath), shellQuote(dotPath))
   local status = os.execute(cmd)
   -- Lua 5.1/LuaJIT return the exit status (0 on success);
   -- Lua 5.2+ return true on success.
   if status ~= true and status ~= 0 then
      io.stderr:write(string.format("failed to run: %s\n", cmd))
   end
end

-- Saves an SVG of gmodule's forward graph (gmodule.fg) with failedNode
-- highlighted. The file prefix is gmodule.name, or
-- 'nngraph_<inputs>in_<outputs>out' when the module has no name; if
-- <prefix>.svg already exists, a suffix derived from os.tmpname() is
-- appended.
local function onError(failedNode, gmodule)
   local nInputs = gmodule.nInputs or #gmodule.innode.children
   local svgPathPrefix = gmodule.name or string.format(
   'nngraph_%sin_%sout', nInputs, #gmodule.outnode.children)
   if paths.filep(svgPathPrefix .. '.svg') then
      -- os.tmpname() is only used as a source of a unique suffix. On POSIX
      -- it creates the file it names, so that placeholder is removed.
      local tmpName = os.tmpname()
      os.remove(tmpName)
      svgPathPrefix = svgPathPrefix .. '_' .. paths.basename(tmpName)
   end
   local dotStr = customToDot(gmodule.fg, svgPathPrefix, failedNode)
   saveSvg(svgPathPrefix, dotStr)
end

-- The unwrapped nn.gModule methods, captured when this file is loaded so
-- that nngraph.setDebug(false) can put them back after wrapping.
local origFuncs = {}
for _, methodName in ipairs({'runForwardFunction', 'updateGradInput',
                             'accGradParameters'}) do
   origFuncs[methodName] = nn.gModule[methodName]
end

-- Enables or disables graph debugging.
-- When enabled, nn.gModule's runForwardFunction, updateGradInput and
-- accGradParameters are wrapped. If one of them raises an error, an SVG of
-- the module's graph (named after gmodule.name when set) is saved with the
-- failing node, when it can be located, filled in red. The error is then
-- re-raised.
-- When disabled, the original methods are restored on nn.gModule.
function nngraph.setDebug(enable)
   if not enable then
      for funcName, origFunc in pairs(origFuncs) do
         nn.gModule[funcName] = origFunc
      end
      return
   end

   local unpack = unpack or table.unpack
   for funcName, origFunc in pairs(origFuncs) do
      -- Always wraps the captured originals, so repeated calls do not stack
      -- wrappers on top of each other.
      nn.gModule[funcName] = function(...)
         -- The explicit count keeps nil arguments, including ones in the
         -- middle of the list, which unpack(args) alone could drop because
         -- #args is unspecified for tables with holes.
         local nArgs = select('#', ...)
         local args = {...}
         local gmodule = args[1]
         return runChecked(function()
            return origFunc(unpack(args, 1, nArgs))
         end, onError, gmodule)
      end
   end
end

-- Names nngraph.Node values after the local variables that hold them:
-- every nngraph.Node found among the locals of the selected stack frame
-- that has no annotations.name yet is annotated with {name = <variable>}.
-- stackLevel picks the frame relative to the caller; the default 1 is the
-- function that called annotateNodes(). Explicit names are preserved.
function nngraph.annotateNodes(stackLevel)
   -- debug.getlocal counts annotateNodes itself as level 1, hence the +1.
   local frame = (stackLevel or 1) + 1
   local slot = 1
   local localName, localValue = debug.getlocal(frame, slot)
   while localName do
      local needsName = torch.typename(localValue) == "nngraph.Node"
         and not localValue.data.annotations.name
      if needsName then
         localValue:annotate({name = localName})
      end
      slot = slot + 1
      localName, localValue = debug.getlocal(frame, slot)
   end
end

--[[
   SVG visualization for gmodule: renders gmodule.fg with graph.dot under a
   temporary file prefix and opens <prefix>.svg with a platform viewer
   (xdg-open on Linux, Safari on OSX). Raises an error on other systems.
   TODO: add custom coloring with node types
]]
function nngraph.display(gmodule)
   local ffi = require 'ffi'
   local cmd
   if ffi.os == 'Linux' then
      cmd = 'xdg-open'
   elseif ffi.os == 'OSX' then
      cmd = 'open -a Safari'
   else
      -- Without this, the concatenation below fails with an opaque
      -- "attempt to concatenate a nil value".
      error(string.format('nngraph.display: no SVG viewer is known for OS %q',
                          tostring(ffi.os)), 2)
   end
   local fname = os.tmpname()
   -- os.tmpname() may create an empty placeholder file; only its name is
   -- used here (as the output prefix for graph.dot), so it is removed.
   os.remove(fname)
   graph.dot(gmodule.fg, fname, fname)
   os.execute(cmd .. ' ' .. fname .. '.svg')
end