From eb572b1af822afcbf528b7719ca4fc9aa9e69c37 Mon Sep 17 00:00:00 2001 From: Clement Farabet Date: Fri, 4 Sep 2015 16:59:12 -0400 Subject: Whitespace cleanup. --- Edge.lua | 8 +- Node.lua | 144 ++++++++++----------- graphviz.lua | 250 +++++++++++++++++------------------ init.lua | 343 ++++++++++++++++++++++++------------------------- test/test_graph.lua | 216 +++++++++++++++---------------- test/test_graphviz.lua | 44 +++---- 6 files changed, 502 insertions(+), 503 deletions(-) diff --git a/Edge.lua b/Edge.lua index f277613..1ffd670 100644 --- a/Edge.lua +++ b/Edge.lua @@ -1,10 +1,10 @@ --[[ - A Directed Edge class - No methods, just two fields, from and to. +A Directed Edge class +No methods, just two fields, from and to. ]]-- local Edge = torch.class('graph.Edge') function Edge:__init(from,to) - self.from = from - self.to = to + self.from = from + self.to = to end diff --git a/Node.lua b/Node.lua index 459f9a8..1f83c18 100644 --- a/Node.lua +++ b/Node.lua @@ -1,105 +1,105 @@ --[[ - Node class. This class is generally used with edge to add edges into a graph. - graph:add(graph.Edge(graph.Node(),graph.Node())) +Node class. This class is generally used with edge to add edges into a graph. +graph:add(graph.Edge(graph.Node(),graph.Node())) - But, one can also easily use this node class to create a graph. It will register - all the edges into its children table and one can parse the graph from any given node. - The drawback is there will be no global edge table and node table, which is mostly useful - to run algorithms on graphs. If all you need is just a data structure to store data and - run DFS, BFS over the graph, then this method is also quick and nice. +But, one can also easily use this node class to create a graph. It will register +all the edges into its children table and one can parse the graph from any given node. +The drawback is there will be no global edge table and node table, which is mostly useful +to run algorithms on graphs. If all you need is just a data structure to store data and +run DFS, BFS over the graph, then this method is also quick and nice. --]] local Node = torch.class('graph.Node') function Node:__init(d,p) - self.data = d - self.id = 0 - self.children = {} - self.visited = false - self.marked = false + self.data = d + self.id = 0 + self.children = {} + self.visited = false + self.marked = false end function Node:add(child) - local children = self.children - if type(child) == 'table' and not torch.typename(child) then - for i,v in ipairs(child) do - self:add(v) - end - elseif not children[child] then - table.insert(children,child) - children[child] = #children - end + local children = self.children + if type(child) == 'table' and not torch.typename(child) then + for i,v in ipairs(child) do + self:add(v) + end + elseif not children[child] then + table.insert(children,child) + children[child] = #children + end end -- visitor function Node:visit(pre_func,post_func) - if not self.visited then - if pre_func then pre_func(self) end - for i,child in ipairs(self.children) do - child:visit(pre_func, post_func) - end - if post_func then post_func(self) end - end + if not self.visited then + if pre_func then pre_func(self) end + for i,child in ipairs(self.children) do + child:visit(pre_func, post_func) + end + if post_func then post_func(self) end + end end function Node:label() - return tostring(self.data) + return tostring(self.data) end -- Create a graph from the Node traversal function Node:graph() - local g = graph.Graph() - local function build_graph(node) - for i,child in ipairs(node.children) do - g:add(graph.Edge(node,child)) - end - end - self:bfs(build_graph) - return g + local g = graph.Graph() + local function build_graph(node) + for i,child in ipairs(node.children) do + g:add(graph.Edge(node,child)) + end + end + self:bfs(build_graph) + return g end function Node:dfs_dirty(func) - local visitednodes = {} - local dfs_func = function(node) - func(node) - table.insert(visitednodes,node) - end - local dfs_func_pre = function(node) - node.visited = true - end - self:visit(dfs_func_pre, dfs_func) - return visitednodes + local visitednodes = {} + local dfs_func = function(node) + func(node) + table.insert(visitednodes,node) + end + local dfs_func_pre = function(node) + node.visited = true + end + self:visit(dfs_func_pre, dfs_func) + return visitednodes end function Node:dfs(func) - for i,node in ipairs(self:dfs_dirty(func)) do - node.visited = false - end + for i,node in ipairs(self:dfs_dirty(func)) do + node.visited = false + end end function Node:bfs_dirty(func) - local visitednodes = {} - local bfsnodes = {} - local bfs_func = function(node) - func(node) - for i,child in ipairs(node.children) do - if not child.marked then - child.marked = true - table.insert(bfsnodes,child) - end - end - end - table.insert(bfsnodes,self) - self.marked = true - while #bfsnodes > 0 do - local node = table.remove(bfsnodes,1) - table.insert(visitednodes,node) - bfs_func(node) - end - return visitednodes + local visitednodes = {} + local bfsnodes = {} + local bfs_func = function(node) + func(node) + for i,child in ipairs(node.children) do + if not child.marked then + child.marked = true + table.insert(bfsnodes,child) + end + end + end + table.insert(bfsnodes,self) + self.marked = true + while #bfsnodes > 0 do + local node = table.remove(bfsnodes,1) + table.insert(visitednodes,node) + bfs_func(node) + end + return visitednodes end function Node:bfs(func) - for i,node in ipairs(self:bfs_dirty(func)) do - node.marked = false - end + for i,node in ipairs(self:bfs_dirty(func)) do + node.marked = false + end end diff --git a/graphviz.lua b/graphviz.lua index 785e325..bff189e 100644 --- a/graphviz.lua +++ b/graphviz.lua @@ -9,90 +9,90 @@ local cgraph ffiOk, ffi = pcall(require, 'ffi') if ffiOk then - ffi.cdef[[ -typedef struct FILE FILE; - -typedef struct Agraph_s Agraph_t; -typedef struct Agnode_s Agnode_t; - -extern Agraph_t *agmemread(const char *cp); -extern char *agget(void *obj, char *name); -extern int agclose(Agraph_t * g); -extern Agnode_t *agfstnode(Agraph_t * g); -extern Agnode_t *agnxtnode(Agraph_t * g, Agnode_t * n); -extern Agnode_t *aglstnode(Agraph_t * g); -extern Agnode_t *agprvnode(Agraph_t * g, Agnode_t * n); - -typedef struct Agraph_s graph_t; -typedef struct GVJ_s GVJ_t; -typedef struct GVG_s GVG_t; -typedef struct GVC_s GVC_t; -extern GVC_t *gvContext(void); -extern int gvLayout(GVC_t *context, graph_t *g, const char *engine); -extern int gvRender(GVC_t *context, graph_t *g, const char *format, FILE *out); -extern int gvFreeLayout(GVC_t *context, graph_t *g); -extern int gvFreeContext(GVC_t *context); -]] - graphvizOk, graphviz = pcall(function() return ffi.load('libgvc') end) - if not graphvizOk then + ffi.cdef[[ + typedef struct FILE FILE; + + typedef struct Agraph_s Agraph_t; + typedef struct Agnode_s Agnode_t; + + extern Agraph_t *agmemread(const char *cp); + extern char *agget(void *obj, char *name); + extern int agclose(Agraph_t * g); + extern Agnode_t *agfstnode(Agraph_t * g); + extern Agnode_t *agnxtnode(Agraph_t * g, Agnode_t * n); + extern Agnode_t *aglstnode(Agraph_t * g); + extern Agnode_t *agprvnode(Agraph_t * g, Agnode_t * n); + + typedef struct Agraph_s graph_t; + typedef struct GVJ_s GVJ_t; + typedef struct GVG_s GVG_t; + typedef struct GVC_s GVC_t; + extern GVC_t *gvContext(void); + extern int gvLayout(GVC_t *context, graph_t *g, const char *engine); + extern int gvRender(GVC_t *context, graph_t *g, const char *format, FILE *out); + extern int gvFreeLayout(GVC_t *context, graph_t *g); + extern int gvFreeContext(GVC_t *context); + ]] + graphvizOk, graphviz = pcall(function() return ffi.load('libgvc') end) + if not graphvizOk then graphvizOk, graphviz = pcall(function() return ffi.load('libgvc.so.6') end) - end + end - cgraphOk, cgraph = pcall(function() return ffi.load('libcgraph') end) - if not cgraphOk then + cgraphOk, cgraph = pcall(function() return ffi.load('libcgraph') end) + if not cgraphOk then cgraphOk, cgraph = pcall(function() return ffi.load('libcgraph.so.6') end) - end + end else - graphvizOk = false - cgraphOk = false + graphvizOk = false + cgraphOk = false end -- Retrieve attribute data from a graphviz object. local function getAttribute(obj, name) - local res = cgraph.agget(obj, ffi.cast("char*", name)) - assert(res ~= ffi.cast("char*", nil), 'could not get attr ' .. name) - return ffi.string(res) + local res = cgraph.agget(obj, ffi.cast("char*", name)) + assert(res ~= ffi.cast("char*", nil), 'could not get attr ' .. name) + return ffi.string(res) end -- Iterate through nodes of a graphviz graph. local function nodeIterator(graph) - local node = cgraph.agfstnode(graph) - local nextNode - return function() - if node == nil then return end - if node == cgraph.aglstnode(graph) then nextNode = nil end - nextNode = cgraph.agnxtnode(graph, node) - local result = node - node = nextNode - return result - end + local node = cgraph.agfstnode(graph) + local nextNode + return function() + if node == nil then return end + if node == cgraph.aglstnode(graph) then nextNode = nil end + nextNode = cgraph.agnxtnode(graph, node) + local result = node + node = nextNode + return result + end end -- Convert a string of comma-separated numbers to actual numbers. local function extractNumbers(n, attr) - local res = {} - for number in string.gmatch(attr, "[^%,]+") do - table.insert(res, tonumber(number)) - end - assert(#res == n, "attribute is not of expected form") - return unpack(res) + local res = {} + for number in string.gmatch(attr, "[^%,]+") do + table.insert(res, tonumber(number)) + end + assert(#res == n, "attribute is not of expected form") + return unpack(res) end -- Transform from graphviz coordinates to unit square. local function getRelativePosition(node, bbox) - local x0, y0, w, h = unpack(bbox) - local x, y = extractNumbers(2, getAttribute(node, 'pos')) - local xt = (x - x0) / w - local yt = (y - y0) / h - assert(xt >= 0 and xt <= 1, "bad x coordinate") - assert(yt >= 0 and yt <= 1, "bad y coordinate") - return xt, yt + local x0, y0, w, h = unpack(bbox) + local x, y = extractNumbers(2, getAttribute(node, 'pos')) + local xt = (x - x0) / w + local yt = (y - y0) / h + assert(xt >= 0 and xt <= 1, "bad x coordinate") + assert(yt >= 0 and yt <= 1, "bad y coordinate") + return xt, yt end -- Retrieve a node's ID based on its label string. local function getID(node) - local label = getAttribute(node, 'label') - local res = {string.find(label, "^Node(%d+)")} or {string.find(label, "%((%d+)%)\\n")} - local id = res[3] - assert(id ~= nil, "could not get ID from node label : <" .. tostring(label) .. ">") - return tonumber(id) + local label = getAttribute(node, 'label') + local res = {string.find(label, "^Node(%d+)")} or {string.find(label, "%((%d+)%)\\n")} + local id = res[3] + assert(id ~= nil, "could not get ID from node label : <" .. tostring(label) .. ">") + return tonumber(id) end --[[ Lay out a graph and return the positions of the nodes. @@ -109,55 +109,55 @@ Coordinates are in the interval [0, 1]. ]] function graph.graphvizLayout(g, algorithm) - if not graphvizOk or not cgraphOk then - error("graphviz library could not be loaded.") - end - local nNodes = #g.nodes - local context = graphviz.gvContext() - local graphvizGraph = cgraph.agmemread(g:todot()) - local algorithm = algorithm or "dot" - assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm), - "graphviz layout failed") - assert(0 == graphviz.gvRender(context, graphvizGraph, algorithm, nil), - "graphviz render failed") - - -- Extract bounding box. - local x0, y0, x1, y1 = extractNumbers(4, - getAttribute(graphvizGraph, 'bb'), ",") - local w = x1 - x0 - local h = y1 - y0 - local bbox = { x0, y0, w, h } - - -- Extract node positions. - local positions = torch.zeros(nNodes, 2) - for node in nodeIterator(graphvizGraph) do - local id = getID(node) - local x, y = getRelativePosition(node, bbox) - positions[id][1] = x - positions[id][2] = y - end - - -- Clean up. - graphviz.gvFreeLayout(context, graphvizGraph) - cgraph.agclose(graphvizGraph) - graphviz.gvFreeContext(context) - return positions + if not graphvizOk or not cgraphOk then + error("graphviz library could not be loaded.") + end + local nNodes = #g.nodes + local context = graphviz.gvContext() + local graphvizGraph = cgraph.agmemread(g:todot()) + local algorithm = algorithm or "dot" + assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm), + "graphviz layout failed") + assert(0 == graphviz.gvRender(context, graphvizGraph, algorithm, nil), + "graphviz render failed") + + -- Extract bounding box. + local x0, y0, x1, y1 = extractNumbers(4, + getAttribute(graphvizGraph, 'bb'), ",") + local w = x1 - x0 + local h = y1 - y0 + local bbox = { x0, y0, w, h } + + -- Extract node positions. + local positions = torch.zeros(nNodes, 2) + for node in nodeIterator(graphvizGraph) do + local id = getID(node) + local x, y = getRelativePosition(node, bbox) + positions[id][1] = x + positions[id][2] = y + end + + -- Clean up. + graphviz.gvFreeLayout(context, graphvizGraph) + cgraph.agclose(graphvizGraph) + graphviz.gvFreeContext(context) + return positions end function graph.graphvizFile(g, algorithm, fname) - algorithm = algorithm or 'dot' - local _,_,rendertype = fname:reverse():find('(%a+)%.%w+') - rendertype = rendertype:reverse() - - local context = graphviz.gvContext() - local graphvizGraph = cgraph.agmemread(g:todot()) - assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm), - "graphviz layout failed") - assert(0 == graphviz.gvRender(context, graphvizGraph, rendertype, io.open(fname, 'w')), - "graphviz render failed") - graphviz.gvFreeLayout(context, graphvizGraph) - cgraph.agclose(graphvizGraph) - graphviz.gvFreeContext(context) + algorithm = algorithm or 'dot' + local _,_,rendertype = fname:reverse():find('(%a+)%.%w+') + rendertype = rendertype:reverse() + + local context = graphviz.gvContext() + local graphvizGraph = cgraph.agmemread(g:todot()) + assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm), + "graphviz layout failed") + assert(0 == graphviz.gvRender(context, graphvizGraph, rendertype, io.open(fname, 'w')), + "graphviz render failed") + graphviz.gvFreeLayout(context, graphvizGraph) + cgraph.agclose(graphvizGraph) + graphviz.gvFreeContext(context) end --[[ @@ -167,25 +167,25 @@ Args: * `g` - graph to display * `title` - Title to display in the graph * `fname` - [optional] if given it should contain a file name without an extension, - the graph is saved on disk as fname.svg and display is not shown. If not given - the graph is shown on qt display (you need to have qtsvg installed and running qlua) +the graph is saved on disk as fname.svg and display is not shown. If not given +the graph is shown on qt display (you need to have qtsvg installed and running qlua) Returns: * `qs` - the window handle for the qt display (if fname given) or nil ]] function graph.dot(g,title,fname) - local qt_display = fname == nil - fname = fname or os.tmpname() - local fnsvg = fname .. '.svg' - local fndot = fname .. '.dot' - graph.graphvizFile(g, 'dot', fnsvg) - graph.graphvizFile(g, 'dot', fndot) - if qt_display then - require 'qtsvg' - local qs = qt.QSvgWidget(fnsvg) - qs:show() - os.remove(fnsvg) - os.remove(fndot) - return qs - end + local qt_display = fname == nil + fname = fname or os.tmpname() + local fnsvg = fname .. '.svg' + local fndot = fname .. '.dot' + graph.graphvizFile(g, 'dot', fnsvg) + graph.graphvizFile(g, 'dot', fndot) + if qt_display then + require 'qtsvg' + local qs = qt.QSvgWidget(fnsvg) + qs:show() + os.remove(fnsvg) + os.remove(fndot) + return qs + end end diff --git a/init.lua b/init.lua index 3cf6e40..f3f25dd 100644 --- a/init.lua +++ b/init.lua @@ -8,227 +8,226 @@ torch.include('graph','Edge.lua') --[[ - Defines a graph and general operations on grpahs like topsort, - connected components, ... - uses two tables, one for nodes, one for edges +Defines a graph and general operations on grpahs like topsort, +connected components, ... +uses two tables, one for nodes, one for edges ]]-- local Graph = torch.class('graph.Graph') function Graph:__init() - self.nodes = {} - self.edges = {} + self.nodes = {} + self.edges = {} end -- add a new edge into the graph. -- an edge has two fields, from and to that are inserted into the -- nodes table. the edge itself is inserted into the edges table. function Graph:add(edge) - if type(edge) ~= 'table' then - error('graph.Edge or {graph.Edges} expected') - end - if torch.typename(edge) then - -- add edge - if not self.edges[edge] then - table.insert(self.edges,edge) - self.edges[edge] = #self.edges - end - -- add from node - if not self.nodes[edge.from] then - table.insert(self.nodes,edge.from) - self.nodes[edge.from] = #self.nodes - end - -- add to node - if not self.nodes[edge.to] then - table.insert(self.nodes,edge.to) - self.nodes[edge.to] = #self.nodes - end - -- add the edge to the node for parsing in nodes - edge.from:add(edge.to) - edge.from.id = self.nodes[edge.from] - edge.to.id = self.nodes[edge.to] - else - for i,e in ipairs(edge) do - self:add(e) - end - end + if type(edge) ~= 'table' then + error('graph.Edge or {graph.Edges} expected') + end + if torch.typename(edge) then + -- add edge + if not self.edges[edge] then + table.insert(self.edges,edge) + self.edges[edge] = #self.edges + end + -- add from node + if not self.nodes[edge.from] then + table.insert(self.nodes,edge.from) + self.nodes[edge.from] = #self.nodes + end + -- add to node + if not self.nodes[edge.to] then + table.insert(self.nodes,edge.to) + self.nodes[edge.to] = #self.nodes + end + -- add the edge to the node for parsing in nodes + edge.from:add(edge.to) + edge.from.id = self.nodes[edge.from] + edge.to.id = self.nodes[edge.to] + else + for i,e in ipairs(edge) do + self:add(e) + end + end end -- Clone a Graph -- this will create new nodes, but will share the data. -- Note that primitive data types like numbers can not be shared function Graph:clone() - local clone = graph.Graph() - local nodes = {} - for i,n in ipairs(self.nodes) do - table.insert(nodes,n.new(n.data)) - end - for i,e in ipairs(self.edges) do - local from = nodes[self.nodes[e.from]] - local to = nodes[self.nodes[e.to]] - clone:add(e.new(from,to)) - end - return clone + local clone = graph.Graph() + local nodes = {} + for i,n in ipairs(self.nodes) do + table.insert(nodes,n.new(n.data)) + end + for i,e in ipairs(self.edges) do + local from = nodes[self.nodes[e.from]] + local to = nodes[self.nodes[e.to]] + clone:add(e.new(from,to)) + end + return clone end -- It returns a new graph where the edges are reversed. -- The nodes share the data. Note that primitive data types can -- not be shared. function Graph:reverse() - local rg = graph.Graph() - local mapnodes = {} - for i,e in ipairs(self.edges) do - mapnodes[e.from] = mapnodes[e.from] or e.from.new(e.from.data) - mapnodes[e.to] = mapnodes[e.to] or e.to.new(e.to.data) - local from = mapnodes[e.from] - local to = mapnodes[e.to] - rg:add(e.new(to,from)) - end - return rg,mapnodes + local rg = graph.Graph() + local mapnodes = {} + for i,e in ipairs(self.edges) do + mapnodes[e.from] = mapnodes[e.from] or e.from.new(e.from.data) + mapnodes[e.to] = mapnodes[e.to] or e.to.new(e.to.data) + local from = mapnodes[e.from] + local to = mapnodes[e.to] + rg:add(e.new(to,from)) + end + return rg,mapnodes end --[[ - Topological Sort - ** This is not finished. OK for graphs with single root. +Topological Sort +** This is not finished. OK for graphs with single root. ]]-- function Graph:topsort() - -- reverse the graph - local rg,map = self:reverse() - local rmap = {} - for k,v in pairs(map) do - rmap[v] = k - end - - -- work on the sorted graph - local sortednodes = {} - local rootnodes = rg:roots() - - if #rootnodes == 0 then - error('Graph has cycles') - end - - -- run - for i,root in ipairs(rootnodes) do - root:dfs(function(node) table.insert(sortednodes,rmap[node]) end) - end - - if #sortednodes ~= #self.nodes then - error('Graph has cycles') - end - return sortednodes,rg,rootnodes + -- reverse the graph + local rg,map = self:reverse() + local rmap = {} + for k,v in pairs(map) do + rmap[v] = k + end + + -- work on the sorted graph + local sortednodes = {} + local rootnodes = rg:roots() + + if #rootnodes == 0 then + error('Graph has cycles') + end + + -- run + for i,root in ipairs(rootnodes) do + root:dfs(function(node) table.insert(sortednodes,rmap[node]) end) + end + + if #sortednodes ~= #self.nodes then + error('Graph has cycles') + end + return sortednodes,rg,rootnodes end -- find root nodes function Graph:roots() - local edges = self.edges - local rootnodes = {} - for i,edge in ipairs(edges) do - --table.insert(rootnodes,edge.from) - if not rootnodes[edge.from] then - rootnodes[edge.from] = #rootnodes+1 - end - end - for i,edge in ipairs(edges) do - if rootnodes[edge.to] then - rootnodes[edge.to] = nil - end - end - local roots = {} - for root,i in pairs(rootnodes) do - table.insert(roots, root) - end - table.sort(roots,function(a,b) return self.nodes[a] < self.nodes[b] end ) - return roots + local edges = self.edges + local rootnodes = {} + for i,edge in ipairs(edges) do + --table.insert(rootnodes,edge.from) + if not rootnodes[edge.from] then + rootnodes[edge.from] = #rootnodes+1 + end + end + for i,edge in ipairs(edges) do + if rootnodes[edge.to] then + rootnodes[edge.to] = nil + end + end + local roots = {} + for root,i in pairs(rootnodes) do + table.insert(roots, root) + end + table.sort(roots,function(a,b) return self.nodes[a] < self.nodes[b] end ) + return roots end -- find root nodes function Graph:leaves() - local edges = self.edges - local leafnodes = {} - for i,edge in ipairs(edges) do - --table.insert(rootnodes,edge.from) - if not leafnodes[edge.to] then - leafnodes[edge.to] = #leafnodes+1 - end - end - for i,edge in ipairs(edges) do - if leafnodes[edge.from] then - leafnodes[edge.from] = nil - end - end - local leaves = {} - for leaf,i in pairs(leafnodes) do - table.insert(leaves, leaf) - end - table.sort(leaves,function(a,b) return self.nodes[a] < self.nodes[b] end ) - return leaves + local edges = self.edges + local leafnodes = {} + for i,edge in ipairs(edges) do + --table.insert(rootnodes,edge.from) + if not leafnodes[edge.to] then + leafnodes[edge.to] = #leafnodes+1 + end + end + for i,edge in ipairs(edges) do + if leafnodes[edge.from] then + leafnodes[edge.from] = nil + end + end + local leaves = {} + for leaf,i in pairs(leafnodes) do + table.insert(leaves, leaf) + end + table.sort(leaves,function(a,b) return self.nodes[a] < self.nodes[b] end ) + return leaves end function graph._dotEscape(str) - if string.find(str, '[^a-zA-Z]') then - -- Escape newlines and quotes. - local escaped = string.gsub(str, '\n', '\\n') - escaped = string.gsub(escaped, '"', '\\"') - str = '"' .. escaped .. '"' - end - return str + if string.find(str, '[^a-zA-Z]') then + -- Escape newlines and quotes. + local escaped = string.gsub(str, '\n', '\\n') + escaped = string.gsub(escaped, '"', '\\"') + str = '"' .. escaped .. '"' + end + return str end --[[ Generate a string like 'color=blue tailport=s' from a table - (e.g. {color = 'blue', tailport = 's'}. Its up to the user to escape - strings properly. +(e.g. {color = 'blue', tailport = 's'}. Its up to the user to escape +strings properly. ]] local function makeAttributeString(attributes) - local str = {} - local keys = {} - for k, _ in pairs(attributes) do - table.insert(keys, k) - end - table.sort(keys) - for _, k in ipairs(keys) do - local v = attributes[k] - table.insert(str, tostring(k) .. '=' .. graph._dotEscape(tostring(v))) - end - return ' ' .. table.concat(str, ' ') + local str = {} + local keys = {} + for k, _ in pairs(attributes) do + table.insert(keys, k) + end + table.sort(keys) + for _, k in ipairs(keys) do + local v = attributes[k] + table.insert(str, tostring(k) .. '=' .. graph._dotEscape(tostring(v))) + end + return ' ' .. table.concat(str, ' ') end function Graph:todot(title) - local nodes = self.nodes - local edges = self.edges - local str = {} - table.insert(str,'digraph G {\n') - if title then - table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n') - end - table.insert(str,'node [shape = oval]; ') - local nodelabels = {} - for i,node in ipairs(nodes) do - local nodeName - if node.graphNodeName then - nodeName = node:graphNodeName() - else - nodeName = 'Node' .. node.id - end - local l = graph._dotEscape(nodeName .. '\n' .. node:label()) - nodelabels[node] = 'n' .. node.id - local graphAttributes = '' - if node.graphNodeAttributes then - graphAttributes = makeAttributeString( - node:graphNodeAttributes()) - end - table.insert(str, - '\n' .. nodelabels[node] .. - '[label=' .. l .. graphAttributes .. '];') - end - table.insert(str,'\n') - for i,edge in ipairs(edges) do - table.insert(str,nodelabels[edge.from] .. ' -> ' .. nodelabels[edge.to] .. ';\n') - end - table.insert(str,'}') - return table.concat(str,'') + local nodes = self.nodes + local edges = self.edges + local str = {} + table.insert(str,'digraph G {\n') + if title then + table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n') + end + table.insert(str,'node [shape = oval]; ') + local nodelabels = {} + for i,node in ipairs(nodes) do + local nodeName + if node.graphNodeName then + nodeName = node:graphNodeName() + else + nodeName = 'Node' .. node.id + end + local l = graph._dotEscape(nodeName .. '\n' .. node:label()) + nodelabels[node] = 'n' .. node.id + local graphAttributes = '' + if node.graphNodeAttributes then + graphAttributes = makeAttributeString( + node:graphNodeAttributes()) + end + table.insert(str, + '\n' .. nodelabels[node] .. + '[label=' .. l .. graphAttributes .. '];') + end + table.insert(str,'\n') + for i,edge in ipairs(edges) do + table.insert(str,nodelabels[edge.from] .. ' -> ' .. nodelabels[edge.to] .. ';\n') + end + table.insert(str,'}') + return table.concat(str,'') end - diff --git a/test/test_graph.lua b/test/test_graph.lua index c5bcf25..3ae9e16 100644 --- a/test/test_graph.lua +++ b/test/test_graph.lua @@ -6,132 +6,132 @@ local tester = totem.Tester() local tests = {} local function create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) - local g = graph.Graph() - local conmat = torch.rand(nlayers, nhiddens, nhiddens):ge(droprate)[{ {1, -2}, {}, {} }] - - -- create nodes - local nodes = { [0] = {}, [nlayers+1] = {} } - local nodecntr = 1 - for inode = 1, ninputs do - local node = graph.Node(nodecntr) - nodes[0][inode] = node - nodecntr = nodecntr + 1 - end - for ilayer = 1, nlayers do - nodes[ilayer] = {} - for inode = 1, nhiddens do - local node = graph.Node(nodecntr) - nodes[ilayer][inode] = node - nodecntr = nodecntr + 1 - end - end - for inode = 1, noutputs do - local node = graph.Node(nodecntr) - nodes[nlayers+1][inode] = node - nodecntr = nodecntr + 1 - end - - -- now connect inputs to all first layer hiddens - for iinput = 1, ninputs do - for inode = 1, nhiddens do - g:add(graph.Edge(nodes[0][iinput], nodes[1][inode])) - end - end - -- now run through layers and connect them - for ilayer = 1, nlayers-1 do - for jnode = 1, nhiddens do - for knode = 1, nhiddens do - if conmat[ilayer][jnode][knode] == 1 then - g:add(graph.Edge(nodes[ilayer][jnode], nodes[ilayer+1][knode])) - end - end - end - end - -- now connect last layer hiddens to outputs - for inode = 1, nhiddens do - for ioutput = 1, noutputs do - g:add(graph.Edge(nodes[nlayers][inode], nodes[nlayers+1][ioutput])) - end - end - - -- there might be nodes left out and not connected to anything. Connect them - for i = 1, nlayers do - for j = 1, nhiddens do - if not g.nodes[nodes[i][j]] then - local jto = torch.random(1, nhiddens) - g:add(graph.Edge(nodes[i][j], nodes[i+1][jto])) - conmat[i][j][jto] = 1 - end - end - end - - return g, conmat + local g = graph.Graph() + local conmat = torch.rand(nlayers, nhiddens, nhiddens):ge(droprate)[{ {1, -2}, {}, {} }] + + -- create nodes + local nodes = { [0] = {}, [nlayers+1] = {} } + local nodecntr = 1 + for inode = 1, ninputs do + local node = graph.Node(nodecntr) + nodes[0][inode] = node + nodecntr = nodecntr + 1 + end + for ilayer = 1, nlayers do + nodes[ilayer] = {} + for inode = 1, nhiddens do + local node = graph.Node(nodecntr) + nodes[ilayer][inode] = node + nodecntr = nodecntr + 1 + end + end + for inode = 1, noutputs do + local node = graph.Node(nodecntr) + nodes[nlayers+1][inode] = node + nodecntr = nodecntr + 1 + end + + -- now connect inputs to all first layer hiddens + for iinput = 1, ninputs do + for inode = 1, nhiddens do + g:add(graph.Edge(nodes[0][iinput], nodes[1][inode])) + end + end + -- now run through layers and connect them + for ilayer = 1, nlayers-1 do + for jnode = 1, nhiddens do + for knode = 1, nhiddens do + if conmat[ilayer][jnode][knode] == 1 then + g:add(graph.Edge(nodes[ilayer][jnode], nodes[ilayer+1][knode])) + end + end + end + end + -- now connect last layer hiddens to outputs + for inode = 1, nhiddens do + for ioutput = 1, noutputs do + g:add(graph.Edge(nodes[nlayers][inode], nodes[nlayers+1][ioutput])) + end + end + + -- there might be nodes left out and not connected to anything. Connect them + for i = 1, nlayers do + for j = 1, nhiddens do + if not g.nodes[nodes[i][j]] then + local jto = torch.random(1, nhiddens) + g:add(graph.Edge(nodes[i][j], nodes[i+1][jto])) + conmat[i][j][jto] = 1 + end + end + end + + return g, conmat end function tests.graph() - local nlayers = torch.random(2,5) - local ninputs = torch.random(1,10) - local noutputs = torch.random(1,10) - local nhiddens = torch.random(10,20) - local droprates = {0, torch.uniform(0.2, 0.8), 1} - for i, droprate in ipairs(droprates) do - local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) - - local nedges = nhiddens * (ninputs+noutputs) + c:sum() - local nnodes = ninputs + noutputs + nhiddens*nlayers - local nroots = ninputs + c:sum(2):eq(0):sum() - local nleaves = noutputs + c:sum(3):eq(0):sum() - - tester:asserteq(#g.edges, nedges, 'wrong number of edges') - tester:asserteq(#g.nodes, nnodes, 'wrong number of nodes') - tester:asserteq(#g:roots(), nroots, 'wrong number of roots') - tester:asserteq(#g:leaves(), nleaves, 'wrong number of leaves') - end + local nlayers = torch.random(2,5) + local ninputs = torch.random(1,10) + local noutputs = torch.random(1,10) + local nhiddens = torch.random(10,20) + local droprates = {0, torch.uniform(0.2, 0.8), 1} + for i, droprate in ipairs(droprates) do + local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) + + local nedges = nhiddens * (ninputs+noutputs) + c:sum() + local nnodes = ninputs + noutputs + nhiddens*nlayers + local nroots = ninputs + c:sum(2):eq(0):sum() + local nleaves = noutputs + c:sum(3):eq(0):sum() + + tester:asserteq(#g.edges, nedges, 'wrong number of edges') + tester:asserteq(#g.nodes, nnodes, 'wrong number of nodes') + tester:asserteq(#g:roots(), nroots, 'wrong number of roots') + tester:asserteq(#g:leaves(), nleaves, 'wrong number of leaves') + end end function tests.test_dfs() - local nlayers = torch.random(5,10) - local ninputs = 1 - local noutputs = 1 - local nhiddens = 1 - local droprate = 0 + local nlayers = torch.random(5,10) + local ninputs = 1 + local noutputs = 1 + local nhiddens = 1 + local droprate = 0 - local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) - local roots = g:roots() - local leaves = g:leaves() + local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) + local roots = g:roots() + local leaves = g:leaves() - tester:asserteq(#roots, 1, 'expected a single root') - tester:asserteq(#leaves, 1, 'expected a single leaf') + tester:asserteq(#roots, 1, 'expected a single root') + tester:asserteq(#leaves, 1, 'expected a single leaf') - local dfs_nodes = {} - roots[1]:dfs(function(node) table.insert(dfs_nodes, node) end) + local dfs_nodes = {} + roots[1]:dfs(function(node) table.insert(dfs_nodes, node) end) - for i, node in ipairs(dfs_nodes) do - tester:asserteq(node.data, #dfs_nodes - i +1, 'dfs order wrong') - end + for i, node in ipairs(dfs_nodes) do + tester:asserteq(node.data, #dfs_nodes - i +1, 'dfs order wrong') + end end function tests.test_bfs() - local nlayers = torch.random(5,10) - local ninputs = 1 - local noutputs = 1 - local nhiddens = 1 - local droprate = 0 + local nlayers = torch.random(5,10) + local ninputs = 1 + local noutputs = 1 + local nhiddens = 1 + local droprate = 0 - local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) - local roots = g:roots() - local leaves = g:leaves() + local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate) + local roots = g:roots() + local leaves = g:leaves() - tester:asserteq(#roots, 1, 'expected a single root') - tester:asserteq(#leaves, 1, 'expected a single leaf') + tester:asserteq(#roots, 1, 'expected a single root') + tester:asserteq(#leaves, 1, 'expected a single leaf') - local bfs_nodes = {} - roots[1]:bfs(function(node) table.insert(bfs_nodes, node) end) + local bfs_nodes = {} + roots[1]:bfs(function(node) table.insert(bfs_nodes, node) end) - for i, node in ipairs(bfs_nodes) do - tester:asserteq(node.data, i, 'bfs order wrong') - end + for i, node in ipairs(bfs_nodes) do + tester:asserteq(node.data, i, 'bfs order wrong') + end end return tester:add(tests):run() diff --git a/test/test_graphviz.lua b/test/test_graphviz.lua index f0f15b2..1d344a9 100644 --- a/test/test_graphviz.lua +++ b/test/test_graphviz.lua @@ -5,33 +5,33 @@ local tester = totem.Tester() local tests = {} function tests.layout() - local g = graph.Graph() - local root = graph.Node(10) - local n1 = graph.Node(1) - local n2 = graph.Node(2) - g:add(graph.Edge(root, n1)) - g:add(graph.Edge(n1, n2)) + local g = graph.Graph() + local root = graph.Node(10) + local n1 = graph.Node(1) + local n2 = graph.Node(2) + g:add(graph.Edge(root, n1)) + g:add(graph.Edge(n1, n2)) - local positions = graph.graphvizLayout(g, 'dot') - local xs = positions:select(2, 1) - local ys = positions:select(2, 2) - tester:assertlt(xs:add(-xs:mean()):norm(), 1e-3, - "x coordinates should be the same") - tester:assertTensorEq(ys, torch.sort(ys, true), 1e-3, - "y coordinates should be ordered") + local positions = graph.graphvizLayout(g, 'dot') + local xs = positions:select(2, 1) + local ys = positions:select(2, 2) + tester:assertlt(xs:add(-xs:mean()):norm(), 1e-3, + "x coordinates should be the same") + tester:assertTensorEq(ys, torch.sort(ys, true), 1e-3, + "y coordinates should be ordered") end function tests.testDotEscape() - tester:assert(graph._dotEscape('red') == 'red', 'Don\'t escape single words') - tester:assert(graph._dotEscape('My label') == '"My label"', - 'Use quotes for spaces') - tester:assert(graph._dotEscape('Non[an') == '"Non[an"', - 'Use quotes for non-alpha characters') - tester:assert(graph._dotEscape('My\nnewline') == '"My\\nnewline"', - 'Escape newlines') - tester:assert(graph._dotEscape('Say "hello"') == '"Say \\"hello\\""', - 'Escape quotes') + tester:assert(graph._dotEscape('red') == 'red', 'Don\'t escape single words') + tester:assert(graph._dotEscape('My label') == '"My label"', + 'Use quotes for spaces') + tester:assert(graph._dotEscape('Non[an') == '"Non[an"', + 'Use quotes for non-alpha characters') + tester:assert(graph._dotEscape('My\nnewline') == '"My\\nnewline"', + 'Escape newlines') + tester:assert(graph._dotEscape('Say "hello"') == '"Say \\"hello\\""', + 'Escape quotes') end -- cgit v1.2.3