Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorJonathan J Hunt <jjhunt@google.com>2014-11-19 21:31:07 +0300
committerJonathan J Hunt <jjhunt@google.com>2015-01-14 15:20:09 +0300
commit6dc7823207a32fb93f3979bf659fb6cce71d1df5 (patch)
tree9990f3d3547759153576accd50b3de7ecbc9e3ea /test
parent518b60b632179b09db4b06c953c2ac4fc66097ce (diff)
Added support for annotating graphs (depends also on changes in graph).
Replaces node:name.
Diffstat (limited to 'test')
-rw-r--r--test/test_nngraph.lua38
1 files changed, 38 insertions, 0 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index 1193dc9..4062c5a 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -303,4 +303,42 @@ function test.test_resizeNestedAs()
checkGradients(net, input)
end
+
+function test.test_annotateGraph()
+ local input = nn.Identity()():annotate(
+ {name = 'Input', description = 'DescA',
+ graphAttributes = {color = 'red'}})
+
+ local hidden_a = nn.Linear(10, 10)(input):annotate(
+ {name = 'Hidden A', description = 'DescB',
+ graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}})
+ local hidden_b = nn.Sigmoid()(hidden_a)
+ local output = nn.Linear(10, 10)(hidden_b)
+ local net = nn.gModule({input}, {output})
+
+ tester:assert(hidden_a:label():match('DescB'))
+ local fg_tmpfile = os.tmpname()
+ local bg_tmpfile = os.tmpname()
+ graph.dot(net.fg, 'Test', fg_tmpfile)
+ graph.dot(net.fg, 'Test BG', bg_tmpfile)
+
+ local function checkDotFile(tmpfile)
+ local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all")
+ tester:assert(
+ dotcontent:match('%[label=%"Input.*DescA.*%" color=red%]'))
+ tester:assert(
+ dotcontent:match(
+ '%[label=%"Hidden A.*DescB.*%".*fontcolor=green.*%]'))
+ tester:assert(
+ dotcontent:match('%[label=%".*DescB.*%".*color=blue.*%]'))
+ tester:assert(
+ dotcontent:match(
+ '%[label=%".*DescB.*%".*tooltip=%".*test_nngraph.lua.*%".*%]'))
+ end
+
+ checkDotFile(fg_tmpfile)
+ checkDotFile(bg_tmpfile)
+end
+
+
tester:add(test):run()