Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/graph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2015-09-11 17:27:10 +0300
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2015-09-11 17:27:10 +0300
commit89bbb5a5aa32e6ba012f1d03812cfde1d9024410 (patch)
treecd6b9c0063bd8626aa209a6fb4ec99a325c05098
parentb22cd99d995c3a9731ee88f9f7bfc0046b87f5a0 (diff)
parent4696bdd0cfec20dcb7f552f5f82e3bb4a84d1bab (diff)
Merge pull request #20 from torch/topsort
improve topsort so that it can handle multiple leaf/root graphs
-rw-r--r--init.lua19
-rw-r--r--test/test_graph.lua91
-rw-r--r--test/test_old.lua2
3 files changed, 107 insertions, 5 deletions
diff --git a/init.lua b/init.lua
index 339345a..7ccec55 100644
--- a/init.lua
+++ b/init.lua
@@ -103,10 +103,11 @@ end
--[[
Topological Sort
-** This is not finished. OK for graphs with single root.
]]--
function Graph:topsort()
+ local dummyRoot
+
-- reverse the graph
local rg,map = self:reverse()
local rmap = {}
@@ -122,14 +123,24 @@ function Graph:topsort()
error('Graph has cycles')
end
- -- run
- for i,root in ipairs(rootnodes) do
- root:dfs(function(node) table.insert(sortednodes,rmap[node]) end)
+ if #rootnodes > 1 then
+ dummyRoot = graph.Node('dummy_root')
+ for _, root in ipairs(rootnodes) do
+ dummyRoot:add(root)
+ end
+ else
+ dummyRoot = rootnodes[1]
end
+ -- run
+ -- the trick is since the dummy node does not exist in original graph,
+ -- rmap[dummyRoot] = nil hence nothing gets inserted into the table
+ dummyRoot:dfs(function(node) table.insert(sortednodes,rmap[node]) end)
+
if #sortednodes ~= #self.nodes then
error('Graph has cycles')
end
+
return sortednodes,rg,rootnodes
end
diff --git a/test/test_graph.lua b/test/test_graph.lua
index 3967a85..7f33581 100644
--- a/test/test_graph.lua
+++ b/test/test_graph.lua
@@ -134,6 +134,71 @@ function tests.test_bfs()
end
end
+function tests.test_topsort()
+ local n1 = graph.Node(1)
+ local n2 = graph.Node(2)
+ local n3 = graph.Node(3)
+ local n4 = graph.Node(4)
+ local g = graph.Graph()
+ g:add(graph.Edge(n1, n2))
+ g:add(graph.Edge(n1, n3))
+ g:add(graph.Edge(n2, n3))
+ g:add(graph.Edge(n2, n4))
+ g:add(graph.Edge(n3, n4))
+
+ local sorted = g:topsort()
+ tester:assert(sorted[1] == n1, 'wrong sort order' )
+ tester:assert(sorted[2] == n2, 'wrong sort order' )
+ tester:assert(sorted[3] == n3, 'wrong sort order' )
+ tester:assert(sorted[4] == n4, 'wrong sort order' )
+
+
+ -- add an extra root
+ local n0 = graph.Node(0)
+ g:add(graph.Edge(n0, n2))
+ local sorted2 = g:topsort()
+ tester:assert(sorted2[1] == n1 or sorted2[1] == n0, 'wrong sort order' )
+ tester:assert(sorted2[5] == n4, 'wrong sort order' )
+
+ -- add an extra leaf
+ local n5 = graph.Node(5)
+ g:add(graph.Edge(n3, n5))
+ local sorted2 = g:topsort()
+ tester:assert(sorted2[1] == n1 or sorted2[1] == n0, 'wrong sort order' )
+ tester:assert(sorted2[6] == n4 or sorted2[6] == n5, 'wrong sort order' )
+ tester:assert(sorted2[5] == n4 or sorted2[5] == n5, 'wrong sort order' )
+ tester:assert(sorted2[6] ~= sorted2[5], 'wrong sort order' )
+
+
+ -- add a bottleneck and a new set of nodes
+ local n11 = graph.Node(11)
+ local n12 = graph.Node(12)
+ local n13 = graph.Node(13)
+ local n14 = graph.Node(14)
+ local n15 = graph.Node(15)
+ local n16 = graph.Node(16)
+
+ g:add(graph.Edge(n4, n11))
+ g:add(graph.Edge(n5, n11))
+ g:add(graph.Edge(n11, n12))
+ g:add(graph.Edge(n11, n13))
+ g:add(graph.Edge(n12, n13))
+ g:add(graph.Edge(n13, n14))
+ g:add(graph.Edge(n14, n15))
+ g:add(graph.Edge(n12, n15))
+ g:add(graph.Edge(n13, n16))
+
+ local sorted3 = g:topsort()
+ -- check all the first 6 sorted elements have data <= 5
+ for i=1, 6 do
+ tester:assert(sorted3[i].data <= 5, 'wrong sort order')
+ end
+ tester:assert(sorted3[7] == n11, 'wrong sort order')
+ tester:assert(sorted3[8] == n12, 'wrong sort order' )
+ tester:assert(sorted3[9] == n13, 'wrong sort order' )
+ tester:assert(sorted3[11] == n16 or sorted3[12] == n16, 'wrong sort order')
+end
+
function tests.test_cycle()
local n1 = graph.Node(1)
local n2 = graph.Node(2)
@@ -161,6 +226,32 @@ function tests.test_cycle()
nocycle:add(graph.Edge(n3, n4))
tester:asserteq(nocycle:hasCycle(), false, 'Graph is not supposed to have cycle')
+
+ local function create_cycle(g, node0, length)
+ local node1, node2 = node0, nil
+ for i = 1, length-1 do
+ node2 = graph.Node('c' .. i)
+ local e = graph.Edge(node1, node2)
+ g:add(e)
+ node1 = node2
+ end
+ g:add(graph.Edge(node1, node0))
+ end
+
+ local bigcycle = graph.Graph()
+ local n1 = graph.Node(1)
+ local n2 = graph.Node(2)
+ local n3 = graph.Node(3)
+ local n4 = graph.Node(4)
+ bigcycle:add(graph.Edge(n1, n2))
+ bigcycle:add(graph.Edge(n1, n3))
+ bigcycle:add(graph.Edge(n2, n3))
+ bigcycle:add(graph.Edge(n2, n4))
+ bigcycle:add(graph.Edge(n3, n4))
+ create_cycle(bigcycle, n2, 5)
+
+ tester:asserteq(cycle:hasCycle(), true, 'Graph is supposed to have cycle')
+
end
return tester:add(tests):run()
diff --git a/test/test_old.lua b/test/test_old.lua
index 003234c..17ae596 100644
--- a/test/test_old.lua
+++ b/test/test_old.lua
@@ -22,4 +22,4 @@ root:dfs(function(node) i=i+1;print('i='..i);print(node:label())end)
print('======= topsort ==========')
s,rg,rn = g:topsort()
-graph.dot(g)
+graph.dot(g, 'g', 'g')