Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2014-07-20 12:59:30 +0400
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2014-07-20 12:59:30 +0400
commit1bbf0d7603616d1bdcbadf567694f60bbf6a9a30 (patch)
treef999592844a346137e6f78007c09aa92e86c8190
parent8dcabe382d631dec0e2e91beffd18474e246d726 (diff)
parent0a3adb8afe4b1752b527d9702a7942f5e02136ac (diff)
Merge pull request #30 from jameskirkpatrick/topic_simple_print
Print graphs that ignores certain nodes
-rw-r--r--init.lua2
-rw-r--r--simple_print.lua125
-rw-r--r--test.lua14
3 files changed, 141 insertions, 0 deletions
diff --git a/init.lua b/init.lua
index 273de8c..b32639b 100644
--- a/init.lua
+++ b/init.lua
@@ -14,6 +14,8 @@ local istensor = utils.istensor
local istable = utils.istable
local istorchclass = utils.istorchclass
+-- simpler todot functions
+nngraph.simple_print = paths.dofile('simple_print.lua')
-- Modify the __call function to hack into nn.Module
diff --git a/simple_print.lua b/simple_print.lua
new file mode 100644
index 0000000..878db3e
--- /dev/null
+++ b/simple_print.lua
@@ -0,0 +1,125 @@
+local function removeNodeFromEdges(node_id, edges)
+ local from_nodes = {}
+ local to_nodes = {}
+ -- remove edges
+ local idx = 1
+ while idx <= #edges do
+ local edge = edges[idx]
+ if edge.source == node_id then
+ local to_node = edges[idx].target
+ table.insert(to_nodes, to_node)
+ table.remove(edges, idx)
+ elseif edge.target == node_id then
+ local from_node = edges[idx].source
+ table.insert(from_nodes, from_node)
+ table.remove(edges, idx)
+ else
+ idx = idx + 1
+ end
+ end
+
+ -- add new edges
+ for _, f in pairs(from_nodes) do
+ for _, t in pairs(to_nodes) do
+ local edge = {source = f, target= t}
+ table.insert(edges, edge)
+ end
+ end
+
+ return edges
+end
+
+local function isNodeGood(node)
+ return node.data and node.data.module and (torch.typename(node.data.module) ~= 'nn.Identity')
+end
+
+local function reIndexNodes(nodes, edges)
+ -- make reverse map
+ local rev_map = {}
+ for idx = 1, #nodes do
+ rev_map[nodes[idx].id] = idx
+ nodes[idx].id = idx
+ end
+ for idx = 1, #edges do
+ local edge = edges[idx]
+ edge.source = rev_map[edge.source]
+ edge.target = rev_map[edge.target]
+ end
+ return nodes, edges
+end
+
+local function cleanGraph(nodes, edges)
+ local idx = 1
+ while idx <= #nodes do
+ local node = nodes[idx]
+ if isNodeGood(node.orig_node) then
+ idx = idx + 1
+ else
+ local id = node.id
+ table.remove(nodes, idx)
+ edges = removeNodeFromEdges(id, edges)
+ end
+ end
+ return reIndexNodes(nodes, edges)
+end
+
+local function loadGraph(graph)
+ local nodes = {}
+ local edges = {}
+ for _, node in ipairs(graph.nodes) do
+ local idx = node.id
+ table.insert(nodes, {id=idx, orig_node = node} )
+ for ich = 1, #node.children do
+ table.insert( edges, {source = idx, target = node.children[ich].id})
+ end
+ end
+ nodes, edges = cleanGraph(nodes, edges)
+ return nodes , edges
+end
+
+local M = {}
+
+function M.todot( graph, title )
+ local nodes, edges = loadGraph(graph)
+ local str = {}
+ table.insert(str,'digraph G {\n')
+ if title then
+ table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n')
+ end
+ table.insert(str,'node [shape = oval]; ')
+ local nodelabels = {}
+ for i,node in ipairs(nodes) do
+ local true_node = node.orig_node
+ local l = '"' .. ( 'Node' .. true_node.id .. '\\n' .. true_node:label() ) .. '"'
+ nodelabels[i] = 'n' .. true_node.id
+ table.insert(str, '\n' .. nodelabels[i] .. '[label=' .. l .. '];')
+ end
+ table.insert(str,'\n')
+ for i,edge in ipairs(edges) do
+ table.insert(str,nodelabels[edge.source] .. ' -> ' .. nodelabels[edge.target] .. ';\n')
+ end
+ table.insert(str,'}')
+ return table.concat(str,'')
+end
+
+function M.dot(g,title,fname)
+ local gv = M.todot(g, title)
+ local fngv = (fname or os.tmpname()) .. '.dot'
+ local fgv = io.open(fngv,'w')
+ fgv:write(gv)
+ fgv:close()
+ local fnsvg = (fname or os.tmpname()) .. '.svg'
+ os.execute('dot -Tsvg -o ' .. fnsvg .. ' ' .. fngv)
+ if not fname then
+ require 'qtsvg'
+ local qs = qt.QSvgWidget(fnsvg)
+ qs:show()
+ os.remove(fngv)
+ os.remove(fnsvg)
+ -- print(fngv,fnpng)
+ return qs
+ end
+end
+
+return M
+
diff --git a/test.lua b/test.lua
index d774eb6..3f9ccb7 100644
--- a/test.lua
+++ b/test.lua
@@ -242,4 +242,18 @@ function t7()
my.eq(gradInput, torch.Tensor{-5}, "gradInput of a fork")
end
+function t8()
+ local in1 = nn.Identity()()
+ local m = nn.Linear(10,10)(in1)
+ local out1 = nn.Tanh()(m)
+ local out2 = nn.Tanh()(m)
+ local out = nn.CAddTable(){out1, out2}
+ local mod = nn.gModule({in1}, {out})
+
+ local dot = nngraph.simple_print.todot(mod.fg, 'bogus')
+ print (dot)
+ nngraph.simple_print.dot(mod.fg, 'bogus', 'new')
+ graph.dot(mod.fg, 'bogus', 'old')
+end
-- t2()
+t8()