Welcome to mirror list, hosted at ThFree Co, Russian Federation.

simple_print.lua - github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 87bf15280c866b18f39a95e55bb9555149c23b46 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
--- Remove every edge touching `node_id` and splice the node out of the graph.
-- Each predecessor of the node (source of an edge into it) is connected
-- directly to each successor (target of an edge out of it), so paths that
-- went through the removed node are preserved.
-- @param node_id id of the node being removed
-- @param edges sequence of {source=..., target=...} tables; modified in place
-- @return the same `edges` table
local function removeNodeFromEdges(node_id, edges)
   local from_nodes = {}
   local to_nodes = {}

   -- Single pass: record neighbours of the removed node and slide every
   -- unrelated edge down over the gaps (O(n) instead of repeated
   -- table.remove shifts). Surviving edges keep their relative order.
   local count = #edges
   local kept = 0
   for i = 1, count do
      local edge = edges[i]
      if edge.source == node_id then
         to_nodes[#to_nodes + 1] = edge.target
      elseif edge.target == node_id then
         from_nodes[#from_nodes + 1] = edge.source
      else
         kept = kept + 1
         edges[kept] = edge
      end
   end
   -- Clear the unused tail from the top down so `edges` stays a sequence.
   for i = count, kept + 1, -1 do
      edges[i] = nil
   end

   -- Bridge predecessors to successors. ipairs (not pairs) fixes the order
   -- in which the new edges are appended, keeping the output deterministic.
   for _, f in ipairs(from_nodes) do
      for _, t in ipairs(to_nodes) do
         edges[#edges + 1] = {source = f, target = t}
      end
   end

   return edges
end

--- Decide whether a node should appear in the simplified printout.
-- A node is kept only when it carries a module whose torch type is not
-- 'nn.Identity'.
-- @param node graph node; its `data` field is inspected
-- @return the falsy `data` / `data.module` value (nil or false) when the
--   node has no module, otherwise a boolean
local function isNodeGood(node)
   local data = node.data
   if not data then
      return data
   end
   local mod = data.module
   if not mod then
      return mod
   end
   return torch.typename(mod) ~= 'nn.Identity'
end

--- Renumber nodes to their positions 1..n and rewrite edges to match.
-- Edge endpoints whose old id belongs to no node in `nodes` become nil.
-- @param nodes sequence of tables with an `id` field; ids rewritten in place
-- @param edges sequence of {source=..., target=...}; rewritten in place
-- @return the same `nodes` and `edges` tables
local function reIndexNodes(nodes, edges)
   -- old id -> new positional id
   local new_index_of = {}
   for position, node in ipairs(nodes) do
      new_index_of[node.id] = position
      node.id = position
   end

   for _, edge in ipairs(edges) do
      edge.source, edge.target = new_index_of[edge.source], new_index_of[edge.target]
   end

   return nodes, edges
end

--- Drop every node rejected by isNodeGood and renumber the survivors.
-- Each dropped node is spliced out of `edges` via removeNodeFromEdges, in
-- the order the nodes appear; surviving nodes keep their relative order.
-- @param nodes sequence of {id=..., orig_node=...}; compacted in place
-- @param edges sequence of {source=..., target=...}
-- @return nodes, edges as returned by reIndexNodes
local function cleanGraph(nodes, edges)
   local count = #nodes
   local kept = 0
   for i = 1, count do
      local entry = nodes[i]
      if isNodeGood(entry.orig_node) then
         -- Slide surviving entries down over the gaps left by dropped ones.
         kept = kept + 1
         nodes[kept] = entry
      else
         edges = removeNodeFromEdges(entry.id, edges)
      end
   end
   -- Clear the now-unused tail from the top down so `nodes` stays a sequence.
   for i = count, kept + 1, -1 do
      nodes[i] = nil
   end
   return reIndexNodes(nodes, edges)
end

--- Flatten `graph` into plain node and edge lists, then clean them.
-- Every graph node becomes {id=<node.id>, orig_node=<node>}, and every
-- entry of its `children` list becomes an edge node -> child.
-- @param graph object with a `nodes` sequence whose entries have `id` and
--   `children` fields
-- @return nodes, edges after cleanGraph
local function loadGraph(graph)
   local nodes, edges = {}, {}
   for _, gnode in ipairs(graph.nodes) do
      local id = gnode.id
      nodes[#nodes + 1] = {id = id, orig_node = gnode}
      local children = gnode.children
      for c = 1, #children do
         edges[#edges + 1] = {source = id, target = children[c].id}
      end
   end
   return cleanGraph(nodes, edges)
end

-- Public module table: exposes M.todot (DOT source) and M.dot (render to SVG).
local M = {}

--- Escape a value for use inside a double-quoted DOT string.
-- Only `"` is escaped. Backslashes are passed through untouched so that any
-- DOT escape sequences already present in the text (e.g. `\n`, `\l`) still
-- take effect.
local function dotQuote(s)
   return (tostring(s):gsub('"', '\\"'))
end

--- Render `graph` as Graphviz DOT source.
-- Nodes rejected by isNodeGood (no module, or an nn.Identity) are omitted
-- and their incoming/outgoing edges are bridged (see loadGraph/cleanGraph).
-- @param graph graph object with a `nodes` sequence whose entries have `id`,
--   `children` and a `:label()` method
-- @param title optional graph title, drawn at the top of the figure
-- @return the DOT source as a string
function M.todot( graph, title )
   local nodes, edges = loadGraph(graph)
   local str = {}
   table.insert(str,'digraph G {\n')
   if title then
      -- Escape quotes so a title containing `"` cannot break the DOT syntax.
      table.insert(str,'labelloc="t";\nlabel="' .. dotQuote(title) .. '";\n')
   end
   table.insert(str,'node [shape = oval]; ')
   -- DOT identifiers indexed by position in `nodes`. loadGraph renumbers edge
   -- endpoints to those positions, while the identifier text keeps the
   -- original graph node id.
   local nodelabels = {}
   for i, node in ipairs(nodes) do
      local true_node = node.orig_node
      local l = '"Node' .. true_node.id .. '\\n' .. dotQuote(true_node:label()) .. '"'
      nodelabels[i] = 'n' .. true_node.id
      table.insert(str, '\n' .. nodelabels[i] .. '[label=' .. l .. '];')
   end
   table.insert(str,'\n')
   for _, edge in ipairs(edges) do
      table.insert(str, nodelabels[edge.source] .. ' -> ' .. nodelabels[edge.target] .. ';\n')
   end
   table.insert(str,'}')
   return table.concat(str,'')
end

--- Quote a string so a POSIX shell treats it as one literal word.
-- NOTE(review): assumes os.execute runs a POSIX shell; cmd.exe on Windows
-- does not understand single quotes.
local function shellQuote(s)
   return "'" .. (tostring(s):gsub("'", "'\\''")) .. "'"
end

--- Render `g` to SVG with the Graphviz `dot` command.
-- With `fname`, writes `<fname>.dot` and `<fname>.svg` and returns nothing.
-- Without it, renders into temporary files, shows the SVG in a
-- qt.QSvgWidget, deletes the temporary files and returns the widget.
-- @param g graph passed to M.todot
-- @param title optional graph title
-- @param fname optional output path prefix (without extension)
-- @return the QSvgWidget when `fname` is nil, otherwise nothing
function M.dot(g,title,fname)
   local gv = M.todot(g, title)
   -- Use a single base name for both files. os.tmpname() may create the file
   -- it names, so it is called only once and that file is removed below.
   local base = fname or os.tmpname()
   local fngv = base .. '.dot'
   local fnsvg = base .. '.svg'
   local fgv = assert(io.open(fngv,'w'))
   fgv:write(gv)
   fgv:close()
   -- Quote paths so names with spaces or shell metacharacters stay intact.
   os.execute('dot -Tsvg -o ' .. shellQuote(fnsvg) .. ' ' .. shellQuote(fngv))
   if not fname then
      require 'qtsvg'
      local qs = qt.QSvgWidget(fnsvg)
      qs:show()
      os.remove(fngv)
      os.remove(fnsvg)
      -- Placeholder possibly created by os.tmpname(); harmless if absent.
      os.remove(base)
      return qs
   end
end

-- Export the module table (M.todot, M.dot).
return M