Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/graph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2015-06-20 23:31:40 +0300
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2015-06-20 23:31:40 +0300
commit464a6e8e892a0ce184f27a9007fe3c5dbac39894 (patch)
tree8c8727f984de0ce7eba552f264f07b5ca8e603bb
parentb71f11aaae7712709ae9f596bcc4aa2fc2acbd1d (diff)
shape up, make tests work
-rw-r--r--AnnotatedNode.lua62
-rw-r--r--Graph.lua65
-rw-r--r--Node.lua3
-rw-r--r--graphviz.lua71
-rw-r--r--test/test_graphviz.lua35
5 files changed, 151 insertions, 85 deletions
diff --git a/AnnotatedNode.lua b/AnnotatedNode.lua
index ad85b84..cf7eeb2 100644
--- a/AnnotatedNode.lua
+++ b/AnnotatedNode.lua
@@ -14,7 +14,8 @@ instance of the AnnotatedNode. (default=2)
]]
function Node:__init(data, infoLevel)
-- level 2 is the calling function
- infoLevel = infoLevel or 2
+ infoLevel = infoLevel or 4
+ assert(type(data) == 'table' and not torch.typename(d), 'expecting a table for data')
parent.__init(self, data)
self.data.annotations = self.data.annotations or {}
@@ -52,3 +53,62 @@ function Node:graphNodeAttributes()
return self.data.annotations.graphAttributes
end
+--[[
+Returns a textual representation of the Node that can be used by graphviz library visualization.
+]]
+function Node:label()
+
+ local function getNanFlag(data)
+ if data:nElement() == 0 then
+ return ''
+ end
+ local isNan = (data:ne(data):sum() > 0)
+ if isNan then
+ return 'NaN'
+ end
+ if data:max() == math.huge then
+ return 'inf'
+ end
+ if data:min() == -math.huge then
+ return '-inf'
+ end
+ return ''
+ end
+ local function getstr(data)
+ if not data then return '' end
+ if torch.isTensor(data) then
+ local nanFlag = getNanFlag(data)
+ local tensorType = 'Tensor'
+ if data:type() ~= torch.Tensor():type() then
+ tensorType = data:type()
+ end
+ return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
+ elseif not torch.isTensor(data) and type(data) == 'table' then
+ local tstr = {}
+ for i,v in ipairs(data) do
+ table.insert(tstr, getstr(v))
+ end
+ return '{' .. table.concat(tstr,',') .. '}'
+ else
+ return tostring(data):gsub('\n','\\l')
+ end
+ end
+ local lbl = {}
+
+ for k,v in pairs(self.data) do
+ local vstr = ''
+ if k == 'annotations' then
+ -- the forwardNodeId is not displayed in the label.
+ else
+ vstr = getstr(v)
+ table.insert(lbl, k .. ' = ' .. vstr)
+ end
+ end
+
+ local desc = ''
+ if self.data.annotations.description then
+ desc = 'desc = ' .. self.data.annotations.description .. '\\n'
+ end
+ return desc .. table.concat(lbl,"\\l")
+end
+
diff --git a/Graph.lua b/Graph.lua
index 1191ca4..c93fe78 100644
--- a/Graph.lua
+++ b/Graph.lua
@@ -165,3 +165,68 @@ function Graph:leaves()
table.sort(leaves,function(a,b) return self.nodes[a] < self.nodes[b] end )
return leaves
end
+
+
+--[[
+todot function for graph class, one can use graphviz to display the graph or save on disk
+
+Args:
+* `title` - title to display on the graph
+ ]]--
+function Graph:todot(title)
+
+ local function dotEscape(str)
+ if string.find(str, '[^a-zA-Z]') then
+ -- Escape newlines and quotes.
+ local escaped = string.gsub(str, '\n', '\\n')
+ escaped = string.gsub(escaped, '"', '\\"')
+ str = '"' .. escaped .. '"'
+ end
+ return str
+ end
+ graph._dotEscape = dotEscape
+
+ --[[ Generate a string like 'color=blue tailport=s' from a table
+ (e.g. {color = 'blue', tailport = 's'}. Its up to the user to escape
+ strings properly.
+ ]]
+ local function makeAttributeString(attributes)
+ local str = {}
+ for k, v in pairs(attributes) do
+ table.insert(str, tostring(k) .. '=' .. dotEscape(tostring(v)))
+ end
+ return ' ' .. table.concat(str, ' ')
+ end
+
+ local nodes = self.nodes
+ local edges = self.edges
+ local str = {}
+ table.insert(str,'digraph G {\n')
+ if title then
+ table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n')
+ end
+ table.insert(str,'node [shape = oval]; ')
+ local nodelabels = {}
+ for i,node in ipairs(nodes) do
+ local nodeName
+ if node.graphNodeName then
+ nodeName = node:graphNodeName()
+ else
+ nodeName = 'Node' .. node.id
+ end
+ local l = dotEscape(nodeName .. '\n' .. node:label())
+ nodelabels[node] = 'n' .. node.id
+ local graphAttributes = ''
+ if node.graphNodeAttributes then
+ graphAttributes = makeAttributeString(node:graphNodeAttributes())
+ end
+ table.insert(str, '\n' .. nodelabels[node] .. '[label=' .. l .. graphAttributes .. '];')
+ end
+ table.insert(str,'\n')
+ for i,edge in ipairs(edges) do
+ table.insert(str,nodelabels[edge.from] .. ' -> ' .. nodelabels[edge.to] .. ';\n')
+ end
+ table.insert(str,'}')
+ return table.concat(str,'')
+end
+
diff --git a/Node.lua b/Node.lua
index 8794927..efcc769 100644
--- a/Node.lua
+++ b/Node.lua
@@ -25,8 +25,7 @@ Args:
to the given table.
]]
function Node:__init(data)
- assert(type(d) == 'table' and not torch.typename(d), 'expecting a table for data')
- self.data = d
+ self.data = data
self.id = 0
self.children = {}
self.visited = false
diff --git a/graphviz.lua b/graphviz.lua
index af24c40..847d2dd 100644
--- a/graphviz.lua
+++ b/graphviz.lua
@@ -70,8 +70,9 @@ end
-- Retrieve a node's ID based on its label string.
local function getID(node)
local label = getAttribute(node, 'label')
- local _, _, id = string.find(label, "^Node(%d+)") or string.find(label, "%((%d+)%)\\n")
- -- assert(id ~= nil, "could not get ID from node label")
+ local res = {string.find(label, "^Node(%d+)")} or {string.find(label, "%((%d+)%)\\n")}
+ local id = res[3]
+ assert(id ~= nil, "could not get ID from node label : <" .. tostring(label) .. ">")
return tonumber(id)
end
@@ -169,69 +170,3 @@ function graph.dot(g,title,fname)
return qs
end
end
-
-
-local function dotEscape(str)
- if string.find(str, '[^a-zA-Z]') then
- -- Escape newlines and quotes.
- local escaped = string.gsub(str, '\n', '\\n')
- escaped = string.gsub(escaped, '"', '\\"')
- str = '"' .. escaped .. '"'
- end
- return str
-end
-graph._dotEscape = dotEscape
-
---[[ Generate a string like 'color=blue tailport=s' from a table
- (e.g. {color = 'blue', tailport = 's'}. Its up to the user to escape
- strings properly.
-]]
-local function makeAttributeString(attributes)
- local str = {}
- for k, v in pairs(attributes) do
- table.insert(str, tostring(k) .. '=' .. dotEscape(tostring(v)))
- end
- return ' ' .. table.concat(str, ' ')
-end
-
-
-local Graph = torch.getmetatable('graph.Graph')
---[[
-todot function for graph class, one can use graphviz to display the graph or save on disk
-
-Args:
-* `title` - title to display on the graph
- ]]--
-function Graph:todot(title)
-
- local nodes = self.nodes
- local edges = self.edges
- local str = {}
- table.insert(str,'digraph G {\n')
- if title then
- table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n')
- end
- table.insert(str,'node [shape = oval]; ')
- local nodelabels = {}
- for i,node in ipairs(nodes) do
- local nodeName
- if node.graphNodeName then
- nodeName = node:graphNodeName()
- else
- nodeName = 'Node' .. node.id
- end
- local l = dotEscape(nodeName .. '\n' .. node:label())
- nodelabels[node] = 'n' .. node.id
- local graphAttributes = ''
- if node.graphNodeAttributes then
- graphAttributes = makeAttributeString(node:graphNodeAttributes())
- end
- table.insert(str, '\n' .. nodelabels[node] .. '[label=' .. l .. graphAttributes .. '];')
- end
- table.insert(str,'\n')
- for i,edge in ipairs(edges) do
- table.insert(str,nodelabels[edge.from] .. ' -> ' .. nodelabels[edge.to] .. ';\n')
- end
- table.insert(str,'}')
- return table.concat(str,'')
-end
diff --git a/test/test_graphviz.lua b/test/test_graphviz.lua
index 4fb250b..46bbb91 100644
--- a/test/test_graphviz.lua
+++ b/test/test_graphviz.lua
@@ -5,21 +5,29 @@ local tester = totem.Tester()
local tests = {}
function tests.test_annotateGraph()
- require 'nngraph'
- local input = nn.Identity()():annotate({name = 'Input', description = 'DescA',
- graphAttributes = {color = 'red'}})
- local hidden_a = nn.Linear(10, 10)(input):annotate({name = 'Hidden A', description = 'DescB',
- graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}})
- local hidden_b = nn.Sigmoid()(hidden_a)
- local output = nn.Linear(10, 10)(hidden_b)
- local net = nn.gModule({input}, {output})
+ local input = graph.AnnotatedNode({}):annotate({name = 'Input',
+ description = 'DescA',
+ graphAttributes = {color = 'red'}})
+ local hidden_a = graph.AnnotatedNode({}):annotate({name = 'Hidden A',
+ description = 'DescB',
+ graphAttributes = {color = 'blue',
+ fontcolor='green',
+ tooltip = 'I am green'}})
+ local hidden_b = graph.AnnotatedNode({})
+ local output = graph.AnnotatedNode({})
+
+ hidden_a:add(input)
+ hidden_b:add(hidden_a)
+ output:add(hidden_b)
+ local bg = output:graph()
+ local fg = bg:reverse()
tester:assert(hidden_a:label():match('DescB'))
local fg_tmpfile = os.tmpname()
local bg_tmpfile = os.tmpname()
- graph.dot(net.fg, 'Test', fg_tmpfile)
- graph.dot(net.fg, 'Test BG', bg_tmpfile)
+ graph.dot(fg, 'Test', fg_tmpfile)
+ graph.dot(bg, 'Test BG', bg_tmpfile)
local function checkDotFile(tmpfile)
local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all")
@@ -46,10 +54,9 @@ function tests.layout()
local positions = graph.graphvizLayout(g, 'dot')
local xs = positions:select(2, 1)
local ys = positions:select(2, 2)
- tester:assertlt(xs:add(-xs:mean()):norm(), 1e-3,
- "x coordinates should be the same")
- tester:assertTensorEq(ys, torch.sort(ys, true), 1e-3,
- "y coordinates should be ordered")
+
+ tester:assertlt(xs:add(-xs:mean()):norm(), 1e-3, "x coordinates should be the same")
+ tester:assertTensorEq(ys, torch.sort(ys, true), 1e-3, "y coordinates should be ordered")
end
function tests.testDotEscape()