Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/graph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2015-06-20 14:15:00 +0300
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2015-06-20 14:15:00 +0300
commitcc9e49a4cc34f265ba36d210b2c78b1f825cc4fe (patch)
tree730724ed353cdef08b97dfec027478e27f6f9c88
parent013e59a3670abf9e558f4749638f4aaac512bbe9 (diff)
move rockspec into rocks folder
mv tests into test folder add a proper test file
-rw-r--r--CMakeLists.txt4
-rw-r--r--rocks/graph-scm-1.rockspec (renamed from graph-scm-1.rockspec)0
-rw-r--r--test/test_graph.lua137
-rw-r--r--test/test_graphviz.lua (renamed from test_graphviz.lua)0
-rw-r--r--test/test_old.lua (renamed from test.lua)0
5 files changed, 138 insertions, 3 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3c8ce97..fcd91a3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,8 +3,6 @@ CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
CMAKE_POLICY(VERSION 2.6)
FIND_PACKAGE(Torch REQUIRED)
-SET(luasrc init.lua Node.lua Edge.lua
- graphviz.lua
- test.lua)
+FILE(GLOB luasrc *.lua)
ADD_TORCH_PACKAGE(graph "" "${luasrc}" "General Graph Package")
diff --git a/graph-scm-1.rockspec b/rocks/graph-scm-1.rockspec
index c16b42a..c16b42a 100644
--- a/graph-scm-1.rockspec
+++ b/rocks/graph-scm-1.rockspec
diff --git a/test/test_graph.lua b/test/test_graph.lua
new file mode 100644
index 0000000..c5bcf25
--- /dev/null
+++ b/test/test_graph.lua
@@ -0,0 +1,137 @@
+
+require 'graph'
+require 'totem'
+
+local tester = totem.Tester()
+local tests = {}
+
+local function create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
+ local g = graph.Graph()
+ local conmat = torch.rand(nlayers, nhiddens, nhiddens):ge(droprate)[{ {1, -2}, {}, {} }]
+
+ -- create nodes
+ local nodes = { [0] = {}, [nlayers+1] = {} }
+ local nodecntr = 1
+ for inode = 1, ninputs do
+ local node = graph.Node(nodecntr)
+ nodes[0][inode] = node
+ nodecntr = nodecntr + 1
+ end
+ for ilayer = 1, nlayers do
+ nodes[ilayer] = {}
+ for inode = 1, nhiddens do
+ local node = graph.Node(nodecntr)
+ nodes[ilayer][inode] = node
+ nodecntr = nodecntr + 1
+ end
+ end
+ for inode = 1, noutputs do
+ local node = graph.Node(nodecntr)
+ nodes[nlayers+1][inode] = node
+ nodecntr = nodecntr + 1
+ end
+
+ -- now connect inputs to all first layer hiddens
+ for iinput = 1, ninputs do
+ for inode = 1, nhiddens do
+ g:add(graph.Edge(nodes[0][iinput], nodes[1][inode]))
+ end
+ end
+ -- now run through layers and connect them
+ for ilayer = 1, nlayers-1 do
+ for jnode = 1, nhiddens do
+ for knode = 1, nhiddens do
+ if conmat[ilayer][jnode][knode] == 1 then
+ g:add(graph.Edge(nodes[ilayer][jnode], nodes[ilayer+1][knode]))
+ end
+ end
+ end
+ end
+ -- now connect last layer hiddens to outputs
+ for inode = 1, nhiddens do
+ for ioutput = 1, noutputs do
+ g:add(graph.Edge(nodes[nlayers][inode], nodes[nlayers+1][ioutput]))
+ end
+ end
+
+ -- there might be nodes left out and not connected to anything. Connect them
+ for i = 1, nlayers do
+ for j = 1, nhiddens do
+ if not g.nodes[nodes[i][j]] then
+ local jto = torch.random(1, nhiddens)
+ g:add(graph.Edge(nodes[i][j], nodes[i+1][jto]))
+ conmat[i][j][jto] = 1
+ end
+ end
+ end
+
+ return g, conmat
+end
+
+
+function tests.graph()
+ local nlayers = torch.random(2,5)
+ local ninputs = torch.random(1,10)
+ local noutputs = torch.random(1,10)
+ local nhiddens = torch.random(10,20)
+ local droprates = {0, torch.uniform(0.2, 0.8), 1}
+ for i, droprate in ipairs(droprates) do
+ local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
+
+ local nedges = nhiddens * (ninputs+noutputs) + c:sum()
+ local nnodes = ninputs + noutputs + nhiddens*nlayers
+ local nroots = ninputs + c:sum(2):eq(0):sum()
+ local nleaves = noutputs + c:sum(3):eq(0):sum()
+
+ tester:asserteq(#g.edges, nedges, 'wrong number of edges')
+ tester:asserteq(#g.nodes, nnodes, 'wrong number of nodes')
+ tester:asserteq(#g:roots(), nroots, 'wrong number of roots')
+ tester:asserteq(#g:leaves(), nleaves, 'wrong number of leaves')
+ end
+end
+
+function tests.test_dfs()
+ local nlayers = torch.random(5,10)
+ local ninputs = 1
+ local noutputs = 1
+ local nhiddens = 1
+ local droprate = 0
+
+ local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
+ local roots = g:roots()
+ local leaves = g:leaves()
+
+ tester:asserteq(#roots, 1, 'expected a single root')
+ tester:asserteq(#leaves, 1, 'expected a single leaf')
+
+ local dfs_nodes = {}
+ roots[1]:dfs(function(node) table.insert(dfs_nodes, node) end)
+
+ for i, node in ipairs(dfs_nodes) do
+ tester:asserteq(node.data, #dfs_nodes - i +1, 'dfs order wrong')
+ end
+end
+
+function tests.test_bfs()
+ local nlayers = torch.random(5,10)
+ local ninputs = 1
+ local noutputs = 1
+ local nhiddens = 1
+ local droprate = 0
+
+ local g,c = create_graph(nlayers, ninputs, noutputs, nhiddens, droprate)
+ local roots = g:roots()
+ local leaves = g:leaves()
+
+ tester:asserteq(#roots, 1, 'expected a single root')
+ tester:asserteq(#leaves, 1, 'expected a single leaf')
+
+ local bfs_nodes = {}
+ roots[1]:bfs(function(node) table.insert(bfs_nodes, node) end)
+
+ for i, node in ipairs(bfs_nodes) do
+ tester:asserteq(node.data, i, 'bfs order wrong')
+ end
+end
+
+return tester:add(tests):run()
diff --git a/test_graphviz.lua b/test/test_graphviz.lua
index f0f15b2..f0f15b2 100644
--- a/test_graphviz.lua
+++ b/test/test_graphviz.lua
diff --git a/test.lua b/test/test_old.lua
index 003234c..003234c 100644
--- a/test.lua
+++ b/test/test_old.lua