diff options
author | Malcolm Reynolds <mareynolds@google.com> | 2016-01-20 17:34:47 +0300 |
---|---|---|
committer | Malcolm Reynolds <mareynolds@google.com> | 2016-01-20 17:34:47 +0300 |
commit | 5448b95c9ee02f28adc0ef424071023ff17e8aa4 (patch) | |
tree | 3565ce467b898b0d1f92f56a4cd493e40e9a973b /test | |
parent | 6cd17c40f4cbc426d20894a58d81360724253333 (diff) |
Store a reverse mapping when wiring together graph, detect unused nodes.
The connectivity checking code was previously unable to detect the
following error case:
local input = nn.Identity()()
local usedOutput = nn.Linear(20, 10)(input)
local unusedOutput = nn.Linear(20, 10)(input)
local gmod = nn.gModule({input}, {usedOutput})
With this fix, when gModule is called it will throw an error, because of
unusedOutput.
This is a backwards incompatible change, but I feel that the current
flexibility is error prone, and I can't see any advantage to it. We have
flushed out a couple of bugs in internal code with this change.
Diffstat (limited to 'test')
-rw-r--r-- | test/test_connectivity.lua | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/test/test_connectivity.lua b/test/test_connectivity.lua new file mode 100644 index 0000000..99b2539 --- /dev/null +++ b/test/test_connectivity.lua @@ -0,0 +1,26 @@ +local totem = require 'totem' +require 'nngraph' +local tests = totem.TestSuite() +local tester = totem.Tester() + +function tests.connectivity() + -- Store debug info here, need to call debug.getinfo on same line as the + -- dangling pointer is declared. + local dInfo + local input = nn.Identity()() + local lin = nn.Linear(20, 10)(input) + -- The Sigmoid does not connect to the output, so should cause an error + -- when we call gModule. + local dangling = nn.Sigmoid()(lin); dInfo = debug.getinfo(1, 'Sl') + local actualOutput = nn.Tanh()(lin) + local errStr = string.format( + 'node declared on %%[%s%%]:%d_ does not connect to gmodule output', + dInfo.short_src, dInfo.currentline) + tester:assertErrorPattern( + function() + return nn.gModule({input}, {actualOutput}) + end, + errStr) +end + +return tester:add(tests):run() |