Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/graph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-09-05 00:14:07 +0300
committerSoumith Chintala <soumith@gmail.com>2015-09-05 00:14:07 +0300
commit8a371e617dc3505f5891d5146dde89ebca8877e1 (patch)
tree596d48139516d32a63077557f54e9da9a541b065
parentd545873580192eda1c664436f4da5fee194fa644 (diff)
parenteb572b1af822afcbf528b7719ca4fc9aa9e69c37 (diff)
Merge pull request #18 from clementfarabet/master
Whitespace cleanup.
-rw-r--r--Edge.lua8
-rw-r--r--Node.lua144
-rw-r--r--graphviz.lua250
-rw-r--r--init.lua343
-rw-r--r--test/test_graph.lua216
-rw-r--r--test/test_graphviz.lua44
6 files changed, 502 insertions, 503 deletions
diff --git a/Edge.lua b/Edge.lua
index f277613..1ffd670 100644
--- a/Edge.lua
+++ b/Edge.lua
@@ -1,10 +1,10 @@
--[[
- A Directed Edge class
- No methods, just two fields, from and to.
+A Directed Edge class
+No methods, just two fields, from and to.
]]--
local Edge = torch.class('graph.Edge')
function Edge:__init(from,to)
- self.from = from
- self.to = to
+ self.from = from
+ self.to = to
end
diff --git a/Node.lua b/Node.lua
index 459f9a8..1f83c18 100644
--- a/Node.lua
+++ b/Node.lua
@@ -1,105 +1,105 @@
--[[
- Node class. This class is generally used with edge to add edges into a graph.
- graph:add(graph.Edge(graph.Node(),graph.Node()))
+Node class. This class is generally used with edge to add edges into a graph.
+graph:add(graph.Edge(graph.Node(),graph.Node()))
- But, one can also easily use this node class to create a graph. It will register
- all the edges into its children table and one can parse the graph from any given node.
- The drawback is there will be no global edge table and node table, which is mostly useful
- to run algorithms on graphs. If all you need is just a data structure to store data and
- run DFS, BFS over the graph, then this method is also quick and nice.
+But, one can also easily use this node class to create a graph. It will register
+all the edges into its children table and one can parse the graph from any given node.
+The drawback is there will be no global edge table and node table, which is mostly useful
+to run algorithms on graphs. If all you need is just a data structure to store data and
+run DFS, BFS over the graph, then this method is also quick and nice.
--]]
local Node = torch.class('graph.Node')
function Node:__init(d,p)
- self.data = d
- self.id = 0
- self.children = {}
- self.visited = false
- self.marked = false
+ self.data = d
+ self.id = 0
+ self.children = {}
+ self.visited = false
+ self.marked = false
end
function Node:add(child)
- local children = self.children
- if type(child) == 'table' and not torch.typename(child) then
- for i,v in ipairs(child) do
- self:add(v)
- end
- elseif not children[child] then
- table.insert(children,child)
- children[child] = #children
- end
+ local children = self.children
+ if type(child) == 'table' and not torch.typename(child) then
+ for i,v in ipairs(child) do
+ self:add(v)
+ end
+ elseif not children[child] then
+ table.insert(children,child)
+ children[child] = #children
+ end
end
-- visitor
function Node:visit(pre_func,post_func)
- if not self.visited then
- if pre_func then pre_func(self) end
- for i,child in ipairs(self.children) do
- child:visit(pre_func, post_func)
- end
- if post_func then post_func(self) end
- end
+ if not self.visited then
+ if pre_func then pre_func(self) end
+ for i,child in ipairs(self.children) do
+ child:visit(pre_func, post_func)
+ end
+ if post_func then post_func(self) end
+ end
end
function Node:label()
- return tostring(self.data)
+ return tostring(self.data)
end
-- Create a graph from the Node traversal
function Node:graph()
- local g = graph.Graph()
- local function build_graph(node)
- for i,child in ipairs(node.children) do
- g:add(graph.Edge(node,child))
- end
- end
- self:bfs(build_graph)
- return g
+ local g = graph.Graph()
+ local function build_graph(node)
+ for i,child in ipairs(node.children) do
+ g:add(graph.Edge(node,child))
+ end
+ end
+ self:bfs(build_graph)
+ return g
end
function Node:dfs_dirty(func)
- local visitednodes = {}
- local dfs_func = function(node)
- func(node)
- table.insert(visitednodes,node)
- end
- local dfs_func_pre = function(node)
- node.visited = true
- end
- self:visit(dfs_func_pre, dfs_func)
- return visitednodes
+ local visitednodes = {}
+ local dfs_func = function(node)
+ func(node)
+ table.insert(visitednodes,node)
+ end
+ local dfs_func_pre = function(node)
+ node.visited = true
+ end
+ self:visit(dfs_func_pre, dfs_func)
+ return visitednodes
end
function Node:dfs(func)
- for i,node in ipairs(self:dfs_dirty(func)) do
- node.visited = false
- end
+ for i,node in ipairs(self:dfs_dirty(func)) do
+ node.visited = false
+ end
end
function Node:bfs_dirty(func)
- local visitednodes = {}
- local bfsnodes = {}
- local bfs_func = function(node)
- func(node)
- for i,child in ipairs(node.children) do
- if not child.marked then
- child.marked = true
- table.insert(bfsnodes,child)
- end
- end
- end
- table.insert(bfsnodes,self)
- self.marked = true
- while #bfsnodes > 0 do
- local node = table.remove(bfsnodes,1)
- table.insert(visitednodes,node)
- bfs_func(node)
- end
- return visitednodes
+ local visitednodes = {}
+ local bfsnodes = {}
+ local bfs_func = function(node)
+ func(node)
+ for i,child in ipairs(node.children) do
+ if not child.marked then
+ child.marked = true
+ table.insert(bfsnodes,child)
+ end
+ end
+ end
+ table.insert(bfsnodes,self)
+ self.marked = true
+ while #bfsnodes > 0 do
+ local node = table.remove(bfsnodes,1)
+ table.insert(visitednodes,node)
+ bfs_func(node)
+ end
+ return visitednodes
end
function Node:bfs(func)
- for i,node in ipairs(self:bfs_dirty(func)) do
- node.marked = false
- end
+ for i,node in ipairs(self:bfs_dirty(func)) do
+ node.marked = false
+ end
end
diff --git a/graphviz.lua b/graphviz.lua
index 785e325..bff189e 100644
--- a/graphviz.lua
+++ b/graphviz.lua
@@ -9,90 +9,90 @@ local cgraph
ffiOk, ffi = pcall(require, 'ffi')
if ffiOk then
- ffi.cdef[[
-typedef struct FILE FILE;
-
-typedef struct Agraph_s Agraph_t;
-typedef struct Agnode_s Agnode_t;
-
-extern Agraph_t *agmemread(const char *cp);
-extern char *agget(void *obj, char *name);
-extern int agclose(Agraph_t * g);
-extern Agnode_t *agfstnode(Agraph_t * g);
-extern Agnode_t *agnxtnode(Agraph_t * g, Agnode_t * n);
-extern Agnode_t *aglstnode(Agraph_t * g);
-extern Agnode_t *agprvnode(Agraph_t * g, Agnode_t * n);
-
-typedef struct Agraph_s graph_t;
-typedef struct GVJ_s GVJ_t;
-typedef struct GVG_s GVG_t;
-typedef struct GVC_s GVC_t;
-extern GVC_t *gvContext(void);
-extern int gvLayout(GVC_t *context, graph_t *g, const char *engine);
-extern int gvRender(GVC_t *context, graph_t *g, const char *format, FILE *out);
-extern int gvFreeLayout(GVC_t *context, graph_t *g);
-extern int gvFreeContext(GVC_t *context);
-]]
- graphvizOk, graphviz = pcall(function() return ffi.load('libgvc') end)
- if not graphvizOk then
+ ffi.cdef[[
+ typedef struct FILE FILE;
+
+ typedef struct Agraph_s Agraph_t;
+ typedef struct Agnode_s Agnode_t;
+
+ extern Agraph_t *agmemread(const char *cp);
+ extern char *agget(void *obj, char *name);
+ extern int agclose(Agraph_t * g);
+ extern Agnode_t *agfstnode(Agraph_t * g);
+ extern Agnode_t *agnxtnode(Agraph_t * g, Agnode_t * n);
+ extern Agnode_t *aglstnode(Agraph_t * g);
+ extern Agnode_t *agprvnode(Agraph_t * g, Agnode_t * n);
+
+ typedef struct Agraph_s graph_t;
+ typedef struct GVJ_s GVJ_t;
+ typedef struct GVG_s GVG_t;
+ typedef struct GVC_s GVC_t;
+ extern GVC_t *gvContext(void);
+ extern int gvLayout(GVC_t *context, graph_t *g, const char *engine);
+ extern int gvRender(GVC_t *context, graph_t *g, const char *format, FILE *out);
+ extern int gvFreeLayout(GVC_t *context, graph_t *g);
+ extern int gvFreeContext(GVC_t *context);
+ ]]
+ graphvizOk, graphviz = pcall(function() return ffi.load('libgvc') end)
+ if not graphvizOk then
graphvizOk, graphviz = pcall(function() return ffi.load('libgvc.so.6') end)
- end
+ end
- cgraphOk, cgraph = pcall(function() return ffi.load('libcgraph') end)
- if not cgraphOk then
+ cgraphOk, cgraph = pcall(function() return ffi.load('libcgraph') end)
+ if not cgraphOk then
cgraphOk, cgraph = pcall(function() return ffi.load('libcgraph.so.6') end)
- end
+ end
else
- graphvizOk = false
- cgraphOk = false
+ graphvizOk = false
+ cgraphOk = false
end
-- Retrieve attribute data from a graphviz object.
local function getAttribute(obj, name)
- local res = cgraph.agget(obj, ffi.cast("char*", name))
- assert(res ~= ffi.cast("char*", nil), 'could not get attr ' .. name)
- return ffi.string(res)
+ local res = cgraph.agget(obj, ffi.cast("char*", name))
+ assert(res ~= ffi.cast("char*", nil), 'could not get attr ' .. name)
+ return ffi.string(res)
end
-- Iterate through nodes of a graphviz graph.
local function nodeIterator(graph)
- local node = cgraph.agfstnode(graph)
- local nextNode
- return function()
- if node == nil then return end
- if node == cgraph.aglstnode(graph) then nextNode = nil end
- nextNode = cgraph.agnxtnode(graph, node)
- local result = node
- node = nextNode
- return result
- end
+ local node = cgraph.agfstnode(graph)
+ local nextNode
+ return function()
+ if node == nil then return end
+ if node == cgraph.aglstnode(graph) then nextNode = nil end
+ nextNode = cgraph.agnxtnode(graph, node)
+ local result = node
+ node = nextNode
+ return result
+ end
end
-- Convert a string of comma-separated numbers to actual numbers.
local function extractNumbers(n, attr)
- local res = {}
- for number in string.gmatch(attr, "[^%,]+") do
- table.insert(res, tonumber(number))
- end
- assert(#res == n, "attribute is not of expected form")
- return unpack(res)
+ local res = {}
+ for number in string.gmatch(attr, "[^%,]+") do
+ table.insert(res, tonumber(number))
+ end
+ assert(#res == n, "attribute is not of expected form")
+ return unpack(res)
end
-- Transform from graphviz coordinates to unit square.
local function getRelativePosition(node, bbox)
- local x0, y0, w, h = unpack(bbox)
- local x, y = extractNumbers(2, getAttribute(node, 'pos'))
- local xt = (x - x0) / w
- local yt = (y - y0) / h
- assert(xt >= 0 and xt <= 1, "bad x coordinate")
- assert(yt >= 0 and yt <= 1, "bad y coordinate")
- return xt, yt
+ local x0, y0, w, h = unpack(bbox)
+ local x, y = extractNumbers(2, getAttribute(node, 'pos'))
+ local xt = (x - x0) / w
+ local yt = (y - y0) / h
+ assert(xt >= 0 and xt <= 1, "bad x coordinate")
+ assert(yt >= 0 and yt <= 1, "bad y coordinate")
+ return xt, yt
end
-- Retrieve a node's ID based on its label string.
local function getID(node)
- local label = getAttribute(node, 'label')
- local res = {string.find(label, "^Node(%d+)")} or {string.find(label, "%((%d+)%)\\n")}
- local id = res[3]
- assert(id ~= nil, "could not get ID from node label : <" .. tostring(label) .. ">")
- return tonumber(id)
+ local label = getAttribute(node, 'label')
+ local res = {string.find(label, "^Node(%d+)")} or {string.find(label, "%((%d+)%)\\n")}
+ local id = res[3]
+ assert(id ~= nil, "could not get ID from node label : <" .. tostring(label) .. ">")
+ return tonumber(id)
end
--[[ Lay out a graph and return the positions of the nodes.
@@ -109,55 +109,55 @@ Coordinates are in the interval [0, 1].
]]
function graph.graphvizLayout(g, algorithm)
- if not graphvizOk or not cgraphOk then
- error("graphviz library could not be loaded.")
- end
- local nNodes = #g.nodes
- local context = graphviz.gvContext()
- local graphvizGraph = cgraph.agmemread(g:todot())
- local algorithm = algorithm or "dot"
- assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm),
- "graphviz layout failed")
- assert(0 == graphviz.gvRender(context, graphvizGraph, algorithm, nil),
- "graphviz render failed")
-
- -- Extract bounding box.
- local x0, y0, x1, y1 = extractNumbers(4,
- getAttribute(graphvizGraph, 'bb'), ",")
- local w = x1 - x0
- local h = y1 - y0
- local bbox = { x0, y0, w, h }
-
- -- Extract node positions.
- local positions = torch.zeros(nNodes, 2)
- for node in nodeIterator(graphvizGraph) do
- local id = getID(node)
- local x, y = getRelativePosition(node, bbox)
- positions[id][1] = x
- positions[id][2] = y
- end
-
- -- Clean up.
- graphviz.gvFreeLayout(context, graphvizGraph)
- cgraph.agclose(graphvizGraph)
- graphviz.gvFreeContext(context)
- return positions
+ if not graphvizOk or not cgraphOk then
+ error("graphviz library could not be loaded.")
+ end
+ local nNodes = #g.nodes
+ local context = graphviz.gvContext()
+ local graphvizGraph = cgraph.agmemread(g:todot())
+ local algorithm = algorithm or "dot"
+ assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm),
+ "graphviz layout failed")
+ assert(0 == graphviz.gvRender(context, graphvizGraph, algorithm, nil),
+ "graphviz render failed")
+
+ -- Extract bounding box.
+ local x0, y0, x1, y1 = extractNumbers(4,
+ getAttribute(graphvizGraph, 'bb'), ",")
+ local w = x1 - x0
+ local h = y1 - y0
+ local bbox = { x0, y0, w, h }
+
+ -- Extract node positions.
+ local positions = torch.zeros(nNodes, 2)
+ for node in nodeIterator(graphvizGraph) do
+ local id = getID(node)
+ local x, y = getRelativePosition(node, bbox)
+ positions[id][1] = x
+ positions[id][2] = y
+ end
+
+ -- Clean up.
+ graphviz.gvFreeLayout(context, graphvizGraph)
+ cgraph.agclose(graphvizGraph)
+ graphviz.gvFreeContext(context)
+ return positions
end
function graph.graphvizFile(g, algorithm, fname)
- algorithm = algorithm or 'dot'
- local _,_,rendertype = fname:reverse():find('(%a+)%.%w+')
- rendertype = rendertype:reverse()
-
- local context = graphviz.gvContext()
- local graphvizGraph = cgraph.agmemread(g:todot())
- assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm),
- "graphviz layout failed")
- assert(0 == graphviz.gvRender(context, graphvizGraph, rendertype, io.open(fname, 'w')),
- "graphviz render failed")
- graphviz.gvFreeLayout(context, graphvizGraph)
- cgraph.agclose(graphvizGraph)
- graphviz.gvFreeContext(context)
+ algorithm = algorithm or 'dot'
+ local _,_,rendertype = fname:reverse():find('(%a+)%.%w+')
+ rendertype = rendertype:reverse()
+
+ local context = graphviz.gvContext()
+ local graphvizGraph = cgraph.agmemread(g:todot())
+ assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm),
+ "graphviz layout failed")
+ assert(0 == graphviz.gvRender(context, graphvizGraph, rendertype, io.open(fname, 'w')),
+ "graphviz render failed")
+ graphviz.gvFreeLayout(context, graphvizGraph)
+ cgraph.agclose(graphvizGraph)
+ graphviz.gvFreeContext(context)
end
--[[
@@ -167,25 +167,25 @@ Args:
* `g` - graph to display
* `title` - Title to display in the graph
* `fname` - [optional] if given it should contain a file name without an extension,
- the graph is saved on disk as fname.svg and display is not shown. If not given
- the graph is shown on qt display (you need to have qtsvg installed and running qlua)
+the graph is saved on disk as fname.svg and display is not shown. If not given
+the graph is shown on qt display (you need to have qtsvg installed and running qlua)
Returns:
* `qs` - the window handle for the qt display (if fname given) or nil
]]
function graph.dot(g,title,fname)
- local qt_display = fname == nil
- fname = fname or os.tmpname()
- local fnsvg = fname .. '.svg'
- local fndot = fname .. '.dot'
- graph.graphvizFile(g, 'dot', fnsvg)
- graph.graphvizFile(g, 'dot', fndot)
- if qt_display then
- require 'qtsvg'
- local qs = qt.QSvgWidget(fnsvg)
- qs:show()
- os.remove(fnsvg)
- os.remove(fndot)
- return qs
- end
+ local qt_display = fname == nil
+ fname = fname or os.tmpname()
+ local fnsvg = fname .. '.svg'
+ local fndot = fname .. '.dot'
+ graph.graphvizFile(g, 'dot', fnsvg)
+ graph.graphvizFile(g, 'dot', fndot)
+ if qt_display then
+ require 'qtsvg'
+ local qs = qt.QSvgWidget(fnsvg)
+ qs:show()
+ os.remove(fnsvg)
+ os.remove(fndot)
+ return qs
+ end
end
diff --git a/init.lua b/init.lua
index 3cf6e40..f3f25dd 100644
--- a/init.lua
+++ b/init.lua
@@ -8,227 +8,226 @@ torch.include('graph','Edge.lua')
--[[
- Defines a graph and general operations on grpahs like topsort,
- connected components, ...
- uses two tables, one for nodes, one for edges
+Defines a graph and general operations on grpahs like topsort,
+connected components, ...
+uses two tables, one for nodes, one for edges
]]--
local Graph = torch.class('graph.Graph')
function Graph:__init()
- self.nodes = {}
- self.edges = {}
+ self.nodes = {}
+ self.edges = {}
end
-- add a new edge into the graph.
-- an edge has two fields, from and to that are inserted into the
-- nodes table. the edge itself is inserted into the edges table.
function Graph:add(edge)
- if type(edge) ~= 'table' then
- error('graph.Edge or {graph.Edges} expected')
- end
- if torch.typename(edge) then
- -- add edge
- if not self.edges[edge] then
- table.insert(self.edges,edge)
- self.edges[edge] = #self.edges
- end
- -- add from node
- if not self.nodes[edge.from] then
- table.insert(self.nodes,edge.from)
- self.nodes[edge.from] = #self.nodes
- end
- -- add to node
- if not self.nodes[edge.to] then
- table.insert(self.nodes,edge.to)
- self.nodes[edge.to] = #self.nodes
- end
- -- add the edge to the node for parsing in nodes
- edge.from:add(edge.to)
- edge.from.id = self.nodes[edge.from]
- edge.to.id = self.nodes[edge.to]
- else
- for i,e in ipairs(edge) do
- self:add(e)
- end
- end
+ if type(edge) ~= 'table' then
+ error('graph.Edge or {graph.Edges} expected')
+ end
+ if torch.typename(edge) then
+ -- add edge
+ if not self.edges[edge] then
+ table.insert(self.edges,edge)
+ self.edges[edge] = #self.edges
+ end
+ -- add from node
+ if not self.nodes[edge.from] then
+ table.insert(self.nodes,edge.from)
+ self.nodes[edge.from] = #self.nodes
+ end
+ -- add to node
+ if not self.nodes[edge.to] then
+ table.insert(self.nodes,edge.to)
+ self.nodes[edge.to] = #self.nodes
+ end
+ -- add the edge to the node for parsing in nodes
+ edge.from:add(edge.to)
+ edge.from.id = self.nodes[edge.from]
+ edge.to.id = self.nodes[edge.to]
+ else
+ for i,e in ipairs(edge) do
+ self:add(e)
+ end
+ end
end
-- Clone a Graph
-- this will create new nodes, but will share the data.
-- Note that primitive data types like numbers can not be shared
function Graph:clone()
- local clone = graph.Graph()
- local nodes = {}
- for i,n in ipairs(self.nodes) do
- table.insert(nodes,n.new(n.data))
- end
- for i,e in ipairs(self.edges) do
- local from = nodes[self.nodes[e.from]]
- local to = nodes[self.nodes[e.to]]
- clone:add(e.new(from,to))
- end
- return clone
+ local clone = graph.Graph()
+ local nodes = {}
+ for i,n in ipairs(self.nodes) do
+ table.insert(nodes,n.new(n.data))
+ end
+ for i,e in ipairs(self.edges) do
+ local from = nodes[self.nodes[e.from]]
+ local to = nodes[self.nodes[e.to]]
+ clone:add(e.new(from,to))
+ end
+ return clone
end
-- It returns a new graph where the edges are reversed.
-- The nodes share the data. Note that primitive data types can
-- not be shared.
function Graph:reverse()
- local rg = graph.Graph()
- local mapnodes = {}
- for i,e in ipairs(self.edges) do
- mapnodes[e.from] = mapnodes[e.from] or e.from.new(e.from.data)
- mapnodes[e.to] = mapnodes[e.to] or e.to.new(e.to.data)
- local from = mapnodes[e.from]
- local to = mapnodes[e.to]
- rg:add(e.new(to,from))
- end
- return rg,mapnodes
+ local rg = graph.Graph()
+ local mapnodes = {}
+ for i,e in ipairs(self.edges) do
+ mapnodes[e.from] = mapnodes[e.from] or e.from.new(e.from.data)
+ mapnodes[e.to] = mapnodes[e.to] or e.to.new(e.to.data)
+ local from = mapnodes[e.from]
+ local to = mapnodes[e.to]
+ rg:add(e.new(to,from))
+ end
+ return rg,mapnodes
end
--[[
- Topological Sort
- ** This is not finished. OK for graphs with single root.
+Topological Sort
+** This is not finished. OK for graphs with single root.
]]--
function Graph:topsort()
- -- reverse the graph
- local rg,map = self:reverse()
- local rmap = {}
- for k,v in pairs(map) do
- rmap[v] = k
- end
-
- -- work on the sorted graph
- local sortednodes = {}
- local rootnodes = rg:roots()
-
- if #rootnodes == 0 then
- error('Graph has cycles')
- end
-
- -- run
- for i,root in ipairs(rootnodes) do
- root:dfs(function(node) table.insert(sortednodes,rmap[node]) end)
- end
-
- if #sortednodes ~= #self.nodes then
- error('Graph has cycles')
- end
- return sortednodes,rg,rootnodes
+ -- reverse the graph
+ local rg,map = self:reverse()
+ local rmap = {}
+ for k,v in pairs(map) do
+ rmap[v] = k
+ end
+
+ -- work on the sorted graph
+ local sortednodes = {}
+ local rootnodes = rg:roots()
+
+ if #rootnodes == 0 then
+ error('Graph has cycles')
+ end
+
+ -- run
+ for i,root in ipairs(rootnodes) do
+ root:dfs(function(node) table.insert(sortednodes,rmap[node]) end)
+ end
+
+ if #sortednodes ~= #self.nodes then
+ error('Graph has cycles')
+ end
+ return sortednodes,rg,rootnodes
end
-- find root nodes
function Graph:roots()
- local edges = self.edges
- local rootnodes = {}
- for i,edge in ipairs(edges) do
- --table.insert(rootnodes,edge.from)
- if not rootnodes[edge.from] then
- rootnodes[edge.from] = #rootnodes+1
- end
- end
- for i,edge in ipairs(edges) do
- if rootnodes[edge.to] then
- rootnodes[edge.to] = nil
- end
- end
- local roots = {}
- for root,i in pairs(rootnodes) do
- table.insert(roots, root)
- end
- table.sort(roots,function(a,b) return self.nodes[a] < self.nodes[b] end )
- return roots
+ local edges = self.edges
+ local rootnodes = {}
+ for i,edge in ipairs(edges) do
+ --table.insert(rootnodes,edge.from)
+ if not rootnodes[edge.from] then
+ rootnodes[edge.from] = #rootnodes+1
+ end
+ end
+ for i,edge in ipairs(edges) do
+ if rootnodes[edge.to] then
+ rootnodes[edge.to] = nil
+ end
+ end
+ local roots = {}
+ for root,i in pairs(rootnodes) do
+ table.insert(roots, root)
+ end
+ table.sort(roots,function(a,b) return self.nodes[a] < self.nodes[b] end )
+ return roots
end
-- find root nodes
function Graph:leaves()
- local edges = self.edges
- local leafnodes = {}
- for i,edge in ipairs(edges) do
- --table.insert(rootnodes,edge.from)
- if not leafnodes[edge.to] then
- leafnodes[edge.to] = #leafnodes+1
- end
- end
- for i,edge in ipairs(edges) do
- if leafnodes[edge.from] then
- leafnodes[edge.from] = nil
- end
- end
- local leaves = {}
- for leaf,i in pairs(leafnodes) do
- table.insert(leaves, leaf)
- end
- table.sort(leaves,function(a,b) return self.nodes[a] < self.nodes[b] end )
- return leaves
+ local edges = self.edges
+ local leafnodes = {}
+ for i,edge in ipairs(edges) do
+ --table.insert(rootnodes,edge.from)
+ if not leafnodes[edge.to] then
+ leafnodes[edge.to] = #leafnodes+1
+ end
+ end
+ for i,edge in ipairs(edges) do
+ if leafnodes[edge.from] then
+ leafnodes[edge.from] = nil
+ end
+ end
+ local leaves = {}
+ for leaf,i in pairs(leafnodes) do
+ table.insert(leaves, leaf)
+ end
+ table.sort(leaves,function(a,b) return self.nodes[a] < self.nodes[b] end )
+ return leaves
end
function graph._dotEscape(str)
- if string.find(str, '[^a-zA-Z]') then
- -- Escape newlines and quotes.
- local escaped = string.gsub(str, '\n', '\\n')
- escaped = string.gsub(escaped, '"', '\\"')
- str = '"' .. escaped .. '"'
- end
- return str
+ if string.find(str, '[^a-zA-Z]') then
+ -- Escape newlines and quotes.
+ local escaped = string.gsub(str, '\n', '\\n')
+ escaped = string.gsub(escaped, '"', '\\"')
+ str = '"' .. escaped .. '"'
+ end
+ return str
end
--[[ Generate a string like 'color=blue tailport=s' from a table
- (e.g. {color = 'blue', tailport = 's'}. Its up to the user to escape
- strings properly.
+(e.g. {color = 'blue', tailport = 's'}. Its up to the user to escape
+strings properly.
]]
local function makeAttributeString(attributes)
- local str = {}
- local keys = {}
- for k, _ in pairs(attributes) do
- table.insert(keys, k)
- end
- table.sort(keys)
- for _, k in ipairs(keys) do
- local v = attributes[k]
- table.insert(str, tostring(k) .. '=' .. graph._dotEscape(tostring(v)))
- end
- return ' ' .. table.concat(str, ' ')
+ local str = {}
+ local keys = {}
+ for k, _ in pairs(attributes) do
+ table.insert(keys, k)
+ end
+ table.sort(keys)
+ for _, k in ipairs(keys) do
+ local v = attributes[k]
+ table.insert(str, tostring(k) .. '=' .. graph._dotEscape(tostring(v)))
+ end
+ return ' ' .. table.concat(str, ' ')
end
function Graph:todot(title)
- local nodes = self.nodes
- local edges = self.edges
- local str = {}
- table.insert(str,'digraph G {\n')
- if title then
- table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n')
- end
- table.insert(str,'node [shape = oval]; ')
- local nodelabels = {}
- for i,node in ipairs(nodes) do
- local nodeName
- if node.graphNodeName then
- nodeName = node:graphNodeName()
- else
- nodeName = 'Node' .. node.id
- end
- local l = graph._dotEscape(nodeName .. '\n' .. node:label())
- nodelabels[node] = 'n' .. node.id
- local graphAttributes = ''
- if node.graphNodeAttributes then
- graphAttributes = makeAttributeString(
- node:graphNodeAttributes())
- end
- table.insert(str,
- '\n' .. nodelabels[node] ..
- '[label=' .. l .. graphAttributes .. '];')
- end
- table.insert(str,'\n')
- for i,edge in ipairs(edges) do
- table.insert(str,nodelabels[edge.from] .. ' -> ' .. nodelabels[edge.to] .. ';\n')
- end
- table.insert(str,'}')
- return table.concat(str,'')
+ local nodes = self.nodes
+ local edges = self.edges
+ local str = {}
+ table.insert(str,'digraph G {\n')
+ if title then
+ table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n')
+ end
+ table.insert(str,'node [shape = oval]; ')
+ local nodelabels = {}
+ for i,node in ipairs(nodes) do
+ local nodeName
+ if node.graphNodeName then
+ nodeName = node:graphNodeName()
+ else
+ nodeName = 'Node' .. node.id
+ end
+ local l = graph._dotEscape(nodeName .. '\n' .. node:label())
+ nodelabels[node] = 'n' .. node.id
+ local graphAttributes = ''
+ if node.graphNodeAttributes then
+ graphAttributes = makeAttributeString(
+ node:graphNodeAttributes())
+ end
+ table.insert(str,
+ '\n' .. nodelabels[node] ..
+ '[label=' .. l .. graphAttributes .. '];')
+ end
+ table.insert(str,'\n')
+ for i,edge in ipairs(edges) do
+ table.insert(str,nodelabels[edge.from] .. ' -> ' .. nodelabels[edge.to] .. ';\n')
+ end
+ table.insert(str,'}')
+ return table.concat(str,'')
end
-
diff --git a/test/test_graph.lua b/test/test_graph.lua
index c5bcf25..3ae9e16 100644
--- a/test/test_graph.lua
+++ b/test/test_graph.lua
@@ -6,132 +6,132 @@ local tester = totem.Tester()
local tests = {}
local function create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
- local g = graph.Graph()
- local conmat = torch.rand(nlayers, nhiddens, nhiddens):ge(droprate)[{ {1, -2}, {}, {} }]
-
- -- create nodes
- local nodes = { [0] = {}, [nlayers+1] = {} }
- local nodecntr = 1
- for inode = 1, ninputs do
- local node = graph.Node(nodecntr)
- nodes[0][inode] = node
- nodecntr = nodecntr + 1
- end
- for ilayer = 1, nlayers do
- nodes[ilayer] = {}
- for inode = 1, nhiddens do
- local node = graph.Node(nodecntr)
- nodes[ilayer][inode] = node
- nodecntr = nodecntr + 1
- end
- end
- for inode = 1, noutputs do
- local node = graph.Node(nodecntr)
- nodes[nlayers+1][inode] = node
- nodecntr = nodecntr + 1
- end
-
- -- now connect inputs to all first layer hiddens
- for iinput = 1, ninputs do
- for inode = 1, nhiddens do
- g:add(graph.Edge(nodes[0][iinput], nodes[1][inode]))
- end
- end
- -- now run through layers and connect them
- for ilayer = 1, nlayers-1 do
- for jnode = 1, nhiddens do
- for knode = 1, nhiddens do
- if conmat[ilayer][jnode][knode] == 1 then
- g:add(graph.Edge(nodes[ilayer][jnode], nodes[ilayer+1][knode]))
- end
- end
- end
- end
- -- now connect last layer hiddens to outputs
- for inode = 1, nhiddens do
- for ioutput = 1, noutputs do
- g:add(graph.Edge(nodes[nlayers][inode], nodes[nlayers+1][ioutput]))
- end
- end
-
- -- there might be nodes left out and not connected to anything. Connect them
- for i = 1, nlayers do
- for j = 1, nhiddens do
- if not g.nodes[nodes[i][j]] then
- local jto = torch.random(1, nhiddens)
- g:add(graph.Edge(nodes[i][j], nodes[i+1][jto]))
- conmat[i][j][jto] = 1
- end
- end
- end
-
- return g, conmat
+ local g = graph.Graph()
+ local conmat = torch.rand(nlayers, nhiddens, nhiddens):ge(droprate)[{ {1, -2}, {}, {} }]
+
+ -- create nodes
+ local nodes = { [0] = {}, [nlayers+1] = {} }
+ local nodecntr = 1
+ for inode = 1, ninputs do
+ local node = graph.Node(nodecntr)
+ nodes[0][inode] = node
+ nodecntr = nodecntr + 1
+ end
+ for ilayer = 1, nlayers do
+ nodes[ilayer] = {}
+ for inode = 1, nhiddens do
+ local node = graph.Node(nodecntr)
+ nodes[ilayer][inode] = node
+ nodecntr = nodecntr + 1
+ end
+ end
+ for inode = 1, noutputs do
+ local node = graph.Node(nodecntr)
+ nodes[nlayers+1][inode] = node
+ nodecntr = nodecntr + 1
+ end
+
+ -- now connect inputs to all first layer hiddens
+ for iinput = 1, ninputs do
+ for inode = 1, nhiddens do
+ g:add(graph.Edge(nodes[0][iinput], nodes[1][inode]))
+ end
+ end
+ -- now run through layers and connect them
+ for ilayer = 1, nlayers-1 do
+ for jnode = 1, nhiddens do
+ for knode = 1, nhiddens do
+ if conmat[ilayer][jnode][knode] == 1 then
+ g:add(graph.Edge(nodes[ilayer][jnode], nodes[ilayer+1][knode]))
+ end
+ end
+ end
+ end
+ -- now connect last layer hiddens to outputs
+ for inode = 1, nhiddens do
+ for ioutput = 1, noutputs do
+ g:add(graph.Edge(nodes[nlayers][inode], nodes[nlayers+1][ioutput]))
+ end
+ end
+
+ -- there might be nodes left out and not connected to anything. Connect them
+ for i = 1, nlayers do
+ for j = 1, nhiddens do
+ if not g.nodes[nodes[i][j]] then
+ local jto = torch.random(1, nhiddens)
+ g:add(graph.Edge(nodes[i][j], nodes[i+1][jto]))
+ conmat[i][j][jto] = 1
+ end
+ end
+ end
+
+ return g, conmat
end
function tests.graph()
- local nlayers = torch.random(2,5)
- local ninputs = torch.random(1,10)
- local noutputs = torch.random(1,10)
- local nhiddens = torch.random(10,20)
- local droprates = {0, torch.uniform(0.2, 0.8), 1}
- for i, droprate in ipairs(droprates) do
- local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
-
- local nedges = nhiddens * (ninputs+noutputs) + c:sum()
- local nnodes = ninputs + noutputs + nhiddens*nlayers
- local nroots = ninputs + c:sum(2):eq(0):sum()
- local nleaves = noutputs + c:sum(3):eq(0):sum()
-
- tester:asserteq(#g.edges, nedges, 'wrong number of edges')
- tester:asserteq(#g.nodes, nnodes, 'wrong number of nodes')
- tester:asserteq(#g:roots(), nroots, 'wrong number of roots')
- tester:asserteq(#g:leaves(), nleaves, 'wrong number of leaves')
- end
+ local nlayers = torch.random(2,5)
+ local ninputs = torch.random(1,10)
+ local noutputs = torch.random(1,10)
+ local nhiddens = torch.random(10,20)
+ local droprates = {0, torch.uniform(0.2, 0.8), 1}
+ for i, droprate in ipairs(droprates) do
+ local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
+
+ local nedges = nhiddens * (ninputs+noutputs) + c:sum()
+ local nnodes = ninputs + noutputs + nhiddens*nlayers
+ local nroots = ninputs + c:sum(2):eq(0):sum()
+ local nleaves = noutputs + c:sum(3):eq(0):sum()
+
+ tester:asserteq(#g.edges, nedges, 'wrong number of edges')
+ tester:asserteq(#g.nodes, nnodes, 'wrong number of nodes')
+ tester:asserteq(#g:roots(), nroots, 'wrong number of roots')
+ tester:asserteq(#g:leaves(), nleaves, 'wrong number of leaves')
+ end
end
function tests.test_dfs()
- local nlayers = torch.random(5,10)
- local ninputs = 1
- local noutputs = 1
- local nhiddens = 1
- local droprate = 0
+ local nlayers = torch.random(5,10)
+ local ninputs = 1
+ local noutputs = 1
+ local nhiddens = 1
+ local droprate = 0
- local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
- local roots = g:roots()
- local leaves = g:leaves()
+ local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
+ local roots = g:roots()
+ local leaves = g:leaves()
- tester:asserteq(#roots, 1, 'expected a single root')
- tester:asserteq(#leaves, 1, 'expected a single leaf')
+ tester:asserteq(#roots, 1, 'expected a single root')
+ tester:asserteq(#leaves, 1, 'expected a single leaf')
- local dfs_nodes = {}
- roots[1]:dfs(function(node) table.insert(dfs_nodes, node) end)
+ local dfs_nodes = {}
+ roots[1]:dfs(function(node) table.insert(dfs_nodes, node) end)
- for i, node in ipairs(dfs_nodes) do
- tester:asserteq(node.data, #dfs_nodes - i +1, 'dfs order wrong')
- end
+ for i, node in ipairs(dfs_nodes) do
+ tester:asserteq(node.data, #dfs_nodes - i +1, 'dfs order wrong')
+ end
end
function tests.test_bfs()
- local nlayers = torch.random(5,10)
- local ninputs = 1
- local noutputs = 1
- local nhiddens = 1
- local droprate = 0
+ local nlayers = torch.random(5,10)
+ local ninputs = 1
+ local noutputs = 1
+ local nhiddens = 1
+ local droprate = 0
- local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
- local roots = g:roots()
- local leaves = g:leaves()
+ local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
+ local roots = g:roots()
+ local leaves = g:leaves()
- tester:asserteq(#roots, 1, 'expected a single root')
- tester:asserteq(#leaves, 1, 'expected a single leaf')
+ tester:asserteq(#roots, 1, 'expected a single root')
+ tester:asserteq(#leaves, 1, 'expected a single leaf')
- local bfs_nodes = {}
- roots[1]:bfs(function(node) table.insert(bfs_nodes, node) end)
+ local bfs_nodes = {}
+ roots[1]:bfs(function(node) table.insert(bfs_nodes, node) end)
- for i, node in ipairs(bfs_nodes) do
- tester:asserteq(node.data, i, 'bfs order wrong')
- end
+ for i, node in ipairs(bfs_nodes) do
+ tester:asserteq(node.data, i, 'bfs order wrong')
+ end
end
return tester:add(tests):run()
diff --git a/test/test_graphviz.lua b/test/test_graphviz.lua
index f0f15b2..1d344a9 100644
--- a/test/test_graphviz.lua
+++ b/test/test_graphviz.lua
@@ -5,33 +5,33 @@ local tester = totem.Tester()
local tests = {}
function tests.layout()
- local g = graph.Graph()
- local root = graph.Node(10)
- local n1 = graph.Node(1)
- local n2 = graph.Node(2)
- g:add(graph.Edge(root, n1))
- g:add(graph.Edge(n1, n2))
+ local g = graph.Graph()
+ local root = graph.Node(10)
+ local n1 = graph.Node(1)
+ local n2 = graph.Node(2)
+ g:add(graph.Edge(root, n1))
+ g:add(graph.Edge(n1, n2))
- local positions = graph.graphvizLayout(g, 'dot')
- local xs = positions:select(2, 1)
- local ys = positions:select(2, 2)
- tester:assertlt(xs:add(-xs:mean()):norm(), 1e-3,
- "x coordinates should be the same")
- tester:assertTensorEq(ys, torch.sort(ys, true), 1e-3,
- "y coordinates should be ordered")
+ local positions = graph.graphvizLayout(g, 'dot')
+ local xs = positions:select(2, 1)
+ local ys = positions:select(2, 2)
+ tester:assertlt(xs:add(-xs:mean()):norm(), 1e-3,
+ "x coordinates should be the same")
+ tester:assertTensorEq(ys, torch.sort(ys, true), 1e-3,
+ "y coordinates should be ordered")
end
function tests.testDotEscape()
- tester:assert(graph._dotEscape('red') == 'red', 'Don\'t escape single words')
- tester:assert(graph._dotEscape('My label') == '"My label"',
- 'Use quotes for spaces')
- tester:assert(graph._dotEscape('Non[an') == '"Non[an"',
- 'Use quotes for non-alpha characters')
- tester:assert(graph._dotEscape('My\nnewline') == '"My\\nnewline"',
- 'Escape newlines')
- tester:assert(graph._dotEscape('Say "hello"') == '"Say \\"hello\\""',
- 'Escape quotes')
+ tester:assert(graph._dotEscape('red') == 'red', 'Don\'t escape single words')
+ tester:assert(graph._dotEscape('My label') == '"My label"',
+ 'Use quotes for spaces')
+ tester:assert(graph._dotEscape('Non[an') == '"Non[an"',
+ 'Use quotes for non-alpha characters')
+ tester:assert(graph._dotEscape('My\nnewline') == '"My\\nnewline"',
+ 'Escape newlines')
+ tester:assert(graph._dotEscape('Say "hello"') == '"Say \\"hello\\""',
+ 'Escape quotes')
end