diff options
author | koray kavukcuoglu <koray@kavukcuoglu.org> | 2014-07-20 12:59:30 +0400 |
---|---|---|
committer | koray kavukcuoglu <koray@kavukcuoglu.org> | 2014-07-20 12:59:30 +0400 |
commit | 1bbf0d7603616d1bdcbadf567694f60bbf6a9a30 (patch) | |
tree | f999592844a346137e6f78007c09aa92e86c8190 | |
parent | 8dcabe382d631dec0e2e91beffd18474e246d726 (diff) | |
parent | 0a3adb8afe4b1752b527d9702a7942f5e02136ac (diff) |
Merge pull request #30 from jameskirkpatrick/topic_simple_print
Print graphs that ignores certain nodes
-rw-r--r-- | init.lua | 2 | ||||
-rw-r--r-- | simple_print.lua | 125 | ||||
-rw-r--r-- | test.lua | 14 |
3 files changed, 141 insertions, 0 deletions
@@ -14,6 +14,8 @@ local istensor = utils.istensor local istable = utils.istable local istorchclass = utils.istorchclass +-- simpler todot functions +nngraph.simple_print = paths.dofile('simple_print.lua') -- Modify the __call function to hack into nn.Module diff --git a/simple_print.lua b/simple_print.lua new file mode 100644 index 0000000..878db3e --- /dev/null +++ b/simple_print.lua @@ -0,0 +1,125 @@ +local function removeNodeFromEdges(node_id, edges) + local from_nodes = {} + local to_nodes = {} + -- remove edges + local idx = 1 + while idx <= #edges do + local edge = edges[idx] + if edge.source == node_id then + local to_node = edges[idx].target + table.insert(to_nodes, to_node) + table.remove(edges, idx) + elseif edge.target == node_id then + local from_node = edges[idx].source + table.insert(from_nodes, from_node) + table.remove(edges, idx) + else + idx = idx + 1 + end + end + + -- add new edges + for _, f in pairs(from_nodes) do + for _, t in pairs(to_nodes) do + local edge = {source = f, target= t} + table.insert(edges, edge) + end + end + + return edges +end + +local function isNodeGood(node) + return node.data and node.data.module and (torch.typename(node.data.module) ~= 'nn.Identity') +end + +local function reIndexNodes(nodes, edges) + -- make reverse map + local rev_map = {} + for idx = 1, #nodes do + rev_map[nodes[idx].id] = idx + nodes[idx].id = idx + end + for idx = 1, #edges do + local edge = edges[idx] + edge.source = rev_map[edge.source] + edge.target = rev_map[edge.target] + end + return nodes, edges +end + +local function cleanGraph(nodes, edges) + local idx = 1 + while idx <= #nodes do + local node = nodes[idx] + if isNodeGood(node.orig_node) then + idx = idx + 1 + else + local id = node.id + table.remove(nodes, idx) + edges = removeNodeFromEdges(id, edges) + end + end + return reIndexNodes(nodes, edges) +end + +local function loadGraph(graph) + local nodes = {} + local edges = {} + for _, node in ipairs(graph.nodes) do + local idx = node.id + table.insert(nodes, {id=idx, orig_node = node} ) + for ich = 1, #node.children do + table.insert( edges, {source = idx, target = node.children[ich].id}) + end + end + nodes, edges = cleanGraph(nodes, edges) + return nodes , edges +end + +local M = {} + +function M.todot( graph, title ) + local nodes, edges = loadGraph(graph) + local str = {} + table.insert(str,'digraph G {\n') + if title then + table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n') + end + table.insert(str,'node [shape = oval]; ') + local nodelabels = {} + for i,node in ipairs(nodes) do + local true_node = node.orig_node + local l = '"' .. ( 'Node' .. true_node.id .. '\\n' .. true_node:label() ) .. '"' + nodelabels[i] = 'n' .. true_node.id + table.insert(str, '\n' .. nodelabels[i] .. '[label=' .. l .. '];') + end + table.insert(str,'\n') + for i,edge in ipairs(edges) do + table.insert(str,nodelabels[edge.source] .. ' -> ' .. nodelabels[edge.target] .. ';\n') + end + table.insert(str,'}') + return table.concat(str,'') +end + +function M.dot(g,title,fname) + local gv = M.todot(g, title) + local fngv = (fname or os.tmpname()) .. '.dot' + local fgv = io.open(fngv,'w') + fgv:write(gv) + fgv:close() + local fnsvg = (fname or os.tmpname()) .. '.svg' + os.execute('dot -Tsvg -o ' .. fnsvg .. ' ' .. fngv) + if not fname then + require 'qtsvg' + local qs = qt.QSvgWidget(fnsvg) + qs:show() + os.remove(fngv) + os.remove(fnsvg) + -- print(fngv,fnpng) + return qs + end +end + +return M + @@ -242,4 +242,18 @@ function t7() my.eq(gradInput, torch.Tensor{-5}, "gradInput of a fork") end +function t8() + local in1 = nn.Identity()() + local m = nn.Linear(10,10)(in1) + local out1 = nn.Tanh()(m) + local out2 = nn.Tanh()(m) + local out = nn.CAddTable(){out1, out2} + local mod = nn.gModule({in1}, {out}) + + local dot = nngraph.simple_print.todot(mod.fg, 'bogus') + print (dot) + nngraph.simple_print.dot(mod.fg, 'bogus', 'new') + graph.dot(mod.fg, 'bogus', 'old') +end -- t2() +t8() |