diff options
author | koray kavukcuoglu <koray@kavukcuoglu.org> | 2015-06-20 23:31:40 +0300 |
---|---|---|
committer | koray kavukcuoglu <koray@kavukcuoglu.org> | 2015-06-20 23:31:40 +0300 |
commit | 464a6e8e892a0ce184f27a9007fe3c5dbac39894 (patch) | |
tree | 8c8727f984de0ce7eba552f264f07b5ca8e603bb | |
parent | b71f11aaae7712709ae9f596bcc4aa2fc2acbd1d (diff) |
shape up, make tests work
-rw-r--r-- | AnnotatedNode.lua | 62 | ||||
-rw-r--r-- | Graph.lua | 65 | ||||
-rw-r--r-- | Node.lua | 3 | ||||
-rw-r--r-- | graphviz.lua | 71 | ||||
-rw-r--r-- | test/test_graphviz.lua | 35 |
5 files changed, 151 insertions, 85 deletions
diff --git a/AnnotatedNode.lua b/AnnotatedNode.lua index ad85b84..cf7eeb2 100644 --- a/AnnotatedNode.lua +++ b/AnnotatedNode.lua @@ -14,7 +14,8 @@ instance of the AnnotatedNode. (default=2) ]] function Node:__init(data, infoLevel) -- level 2 is the calling function - infoLevel = infoLevel or 2 + infoLevel = infoLevel or 4 + assert(type(data) == 'table' and not torch.typename(d), 'expecting a table for data') parent.__init(self, data) self.data.annotations = self.data.annotations or {} @@ -52,3 +53,62 @@ function Node:graphNodeAttributes() return self.data.annotations.graphAttributes end +--[[ +Returns a textual representation of the Node that can be used by graphviz library visualization. +]] +function Node:label() + + local function getNanFlag(data) + if data:nElement() == 0 then + return '' + end + local isNan = (data:ne(data):sum() > 0) + if isNan then + return 'NaN' + end + if data:max() == math.huge then + return 'inf' + end + if data:min() == -math.huge then + return '-inf' + end + return '' + end + local function getstr(data) + if not data then return '' end + if torch.isTensor(data) then + local nanFlag = getNanFlag(data) + local tensorType = 'Tensor' + if data:type() ~= torch.Tensor():type() then + tensorType = data:type() + end + return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag + elseif not torch.isTensor(data) and type(data) == 'table' then + local tstr = {} + for i,v in ipairs(data) do + table.insert(tstr, getstr(v)) + end + return '{' .. table.concat(tstr,',') .. '}' + else + return tostring(data):gsub('\n','\\l') + end + end + local lbl = {} + + for k,v in pairs(self.data) do + local vstr = '' + if k == 'annotations' then + -- the forwardNodeId is not displayed in the label. + else + vstr = getstr(v) + table.insert(lbl, k .. ' = ' .. vstr) + end + end + + local desc = '' + if self.data.annotations.description then + desc = 'desc = ' .. self.data.annotations.description .. '\\n' + end + return desc .. table.concat(lbl,"\\l") +end + @@ -165,3 +165,68 @@ function Graph:leaves() table.sort(leaves,function(a,b) return self.nodes[a] < self.nodes[b] end ) return leaves end + + +--[[ +todot function for graph class, one can use graphviz to display the graph or save on disk + +Args: +* `title` - title to display on the graph + ]]-- +function Graph:todot(title) + + local function dotEscape(str) + if string.find(str, '[^a-zA-Z]') then + -- Escape newlines and quotes. + local escaped = string.gsub(str, '\n', '\\n') + escaped = string.gsub(escaped, '"', '\\"') + str = '"' .. escaped .. '"' + end + return str + end + graph._dotEscape = dotEscape + + --[[ Generate a string like 'color=blue tailport=s' from a table + (e.g. {color = 'blue', tailport = 's'}. Its up to the user to escape + strings properly. + ]] + local function makeAttributeString(attributes) + local str = {} + for k, v in pairs(attributes) do + table.insert(str, tostring(k) .. '=' .. dotEscape(tostring(v))) + end + return ' ' .. table.concat(str, ' ') + end + + local nodes = self.nodes + local edges = self.edges + local str = {} + table.insert(str,'digraph G {\n') + if title then + table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n') + end + table.insert(str,'node [shape = oval]; ') + local nodelabels = {} + for i,node in ipairs(nodes) do + local nodeName + if node.graphNodeName then + nodeName = node:graphNodeName() + else + nodeName = 'Node' .. node.id + end + local l = dotEscape(nodeName .. '\n' .. node:label()) + nodelabels[node] = 'n' .. node.id + local graphAttributes = '' + if node.graphNodeAttributes then + graphAttributes = makeAttributeString(node:graphNodeAttributes()) + end + table.insert(str, '\n' .. nodelabels[node] .. '[label=' .. l .. graphAttributes .. '];') + end + table.insert(str,'\n') + for i,edge in ipairs(edges) do + table.insert(str,nodelabels[edge.from] .. ' -> ' .. nodelabels[edge.to] .. ';\n') + end + table.insert(str,'}') + return table.concat(str,'') +end + @@ -25,8 +25,7 @@ Args: to the given table. ]] function Node:__init(data) - assert(type(d) == 'table' and not torch.typename(d), 'expecting a table for data') - self.data = d + self.data = data self.id = 0 self.children = {} self.visited = false diff --git a/graphviz.lua b/graphviz.lua index af24c40..847d2dd 100644 --- a/graphviz.lua +++ b/graphviz.lua @@ -70,8 +70,9 @@ end -- Retrieve a node's ID based on its label string. local function getID(node) local label = getAttribute(node, 'label') - local _, _, id = string.find(label, "^Node(%d+)") or string.find(label, "%((%d+)%)\\n") - -- assert(id ~= nil, "could not get ID from node label") + local res = {string.find(label, "^Node(%d+)")} or {string.find(label, "%((%d+)%)\\n")} + local id = res[3] + assert(id ~= nil, "could not get ID from node label : <" .. tostring(label) .. ">") return tonumber(id) end @@ -169,69 +170,3 @@ function graph.dot(g,title,fname) return qs end end - - -local function dotEscape(str) - if string.find(str, '[^a-zA-Z]') then - -- Escape newlines and quotes. - local escaped = string.gsub(str, '\n', '\\n') - escaped = string.gsub(escaped, '"', '\\"') - str = '"' .. escaped .. '"' - end - return str -end -graph._dotEscape = dotEscape - ---[[ Generate a string like 'color=blue tailport=s' from a table - (e.g. {color = 'blue', tailport = 's'}. Its up to the user to escape - strings properly. -]] -local function makeAttributeString(attributes) - local str = {} - for k, v in pairs(attributes) do - table.insert(str, tostring(k) .. '=' .. dotEscape(tostring(v))) - end - return ' ' .. table.concat(str, ' ') -end - - -local Graph = torch.getmetatable('graph.Graph') ---[[ -todot function for graph class, one can use graphviz to display the graph or save on disk - -Args: -* `title` - title to display on the graph - ]]-- -function Graph:todot(title) - - local nodes = self.nodes - local edges = self.edges - local str = {} - table.insert(str,'digraph G {\n') - if title then - table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n') - end - table.insert(str,'node [shape = oval]; ') - local nodelabels = {} - for i,node in ipairs(nodes) do - local nodeName - if node.graphNodeName then - nodeName = node:graphNodeName() - else - nodeName = 'Node' .. node.id - end - local l = dotEscape(nodeName .. '\n' .. node:label()) - nodelabels[node] = 'n' .. node.id - local graphAttributes = '' - if node.graphNodeAttributes then - graphAttributes = makeAttributeString(node:graphNodeAttributes()) - end - table.insert(str, '\n' .. nodelabels[node] .. '[label=' .. l .. graphAttributes .. '];') - end - table.insert(str,'\n') - for i,edge in ipairs(edges) do - table.insert(str,nodelabels[edge.from] .. ' -> ' .. nodelabels[edge.to] .. ';\n') - end - table.insert(str,'}') - return table.concat(str,'') -end diff --git a/test/test_graphviz.lua b/test/test_graphviz.lua index 4fb250b..46bbb91 100644 --- a/test/test_graphviz.lua +++ b/test/test_graphviz.lua @@ -5,21 +5,29 @@ local tester = totem.Tester() local tests = {} function tests.test_annotateGraph() - require 'nngraph' - local input = nn.Identity()():annotate({name = 'Input', description = 'DescA', - graphAttributes = {color = 'red'}}) - local hidden_a = nn.Linear(10, 10)(input):annotate({name = 'Hidden A', description = 'DescB', - graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}}) - local hidden_b = nn.Sigmoid()(hidden_a) - local output = nn.Linear(10, 10)(hidden_b) - local net = nn.gModule({input}, {output}) + local input = graph.AnnotatedNode({}):annotate({name = 'Input', + description = 'DescA', + graphAttributes = {color = 'red'}}) + local hidden_a = graph.AnnotatedNode({}):annotate({name = 'Hidden A', + description = 'DescB', + graphAttributes = {color = 'blue', + fontcolor='green', + tooltip = 'I am green'}}) + local hidden_b = graph.AnnotatedNode({}) + local output = graph.AnnotatedNode({}) + + hidden_a:add(input) + hidden_b:add(hidden_a) + output:add(hidden_b) + local bg = output:graph() + local fg = bg:reverse() tester:assert(hidden_a:label():match('DescB')) local fg_tmpfile = os.tmpname() local bg_tmpfile = os.tmpname() - graph.dot(net.fg, 'Test', fg_tmpfile) - graph.dot(net.fg, 'Test BG', bg_tmpfile) + graph.dot(fg, 'Test', fg_tmpfile) + graph.dot(bg, 'Test BG', bg_tmpfile) local function checkDotFile(tmpfile) local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all") @@ -46,10 +54,9 @@ function tests.layout() local positions = graph.graphvizLayout(g, 'dot') local xs = positions:select(2, 1) local ys = positions:select(2, 2) - tester:assertlt(xs:add(-xs:mean()):norm(), 1e-3, - "x coordinates should be the same") - tester:assertTensorEq(ys, torch.sort(ys, true), 1e-3, - "y coordinates should be ordered") + + tester:assertlt(xs:add(-xs:mean()):norm(), 1e-3, "x coordinates should be the same") + tester:assertTensorEq(ys, torch.sort(ys, true), 1e-3, "y coordinates should be ordered") end function tests.testDotEscape() |