Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMalcolm Reynolds <mareynolds@google.com>2016-01-20 17:34:47 +0300
committerMalcolm Reynolds <mareynolds@google.com>2016-01-20 17:34:47 +0300
commit5448b95c9ee02f28adc0ef424071023ff17e8aa4 (patch)
tree3565ce467b898b0d1f92f56a4cd493e40e9a973b /test
parent6cd17c40f4cbc426d20894a58d81360724253333 (diff)
Store a reverse mapping when wiring together graph, detect unused nodes.
The connectivity checking code was previously unable to detect the following error case: local input = nn.Identity()() local usedOutput = nn.Linear(20, 10)(input) local unusedOutput = nn.Linear(20, 10)(input) local gmod = nn.gModule({input}, {usedOutput}) With this fix, when gModule is called it will throw an error, because of unusedOutput. This is a backwards incompatible change, but I feel that the current flexibility is error prone, and I can't see any advantage to it. We have flushed out a couple of bugs in internal code with this change.
Diffstat (limited to 'test')
-rw-r--r--test/test_connectivity.lua26
1 files changed, 26 insertions, 0 deletions
diff --git a/test/test_connectivity.lua b/test/test_connectivity.lua
new file mode 100644
index 0000000..99b2539
--- /dev/null
+++ b/test/test_connectivity.lua
@@ -0,0 +1,26 @@
+local totem = require 'totem'
+require 'nngraph'
+local tests = totem.TestSuite()
+local tester = totem.Tester()
+
+function tests.connectivity()
+ -- Store debug info here, need to call debug.getinfo on same line as the
+ -- dangling pointer is declared.
+ local dInfo
+ local input = nn.Identity()()
+ local lin = nn.Linear(20, 10)(input)
+ -- The Sigmoid does not connect to the output, so should cause an error
+ -- when we call gModule.
+ local dangling = nn.Sigmoid()(lin); dInfo = debug.getinfo(1, 'Sl')
+ local actualOutput = nn.Tanh()(lin)
+ local errStr = string.format(
+ 'node declared on %%[%s%%]:%d_ does not connect to gmodule output',
+ dInfo.short_src, dInfo.currentline)
+ tester:assertErrorPattern(
+ function()
+ return nn.gModule({input}, {actualOutput})
+ end,
+ errStr)
+end
+
+return tester:add(tests):run()