diff options
author | Ivo Danihelka <danihelka@google.com> | 2015-07-23 12:27:56 +0300 |
---|---|---|
committer | Ivo Danihelka <danihelka@google.com> | 2015-07-23 12:27:56 +0300 |
commit | e9454c5fae2cdae12001a38a021a1d0c1894bfd1 (patch) | |
tree | c1e78b9a46a7921c5c6d1c50d17020ff3d29c367 /test | |
parent | d3328034c7fcb020f590249779750b6d25d75110 (diff) |
Added an assert to check the number of inputs to a split.
Diffstat (limited to 'test')
-rw-r--r-- | test/test_nngraph.lua | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index 6698b0d..86bf730 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -358,5 +358,26 @@ function test.test_annotateGraph() checkDotFile(bg_tmpfile) end +function test.test_splitMore() + local nSplits = 2 + local in1 = nn.Identity()() + local out1, out2 = nn.SplitTable(2)(in1):split(nSplits) + + local model = nn.gModule({in1}, {out1, out2}) + local input = torch.randn(10, nSplits + 1) + local ok, result = pcall(model.forward, model, input) + assert(not ok, "the extra input to split should be detected") +end + +function test.test_splitLess() + local nSplits = 3 + local in1 = nn.Identity()() + local out1, out2, out3 = nn.SplitTable(2)(in1):split(nSplits) + + local model = nn.gModule({in1}, {out1, out2, out3}) + local input = torch.randn(10, nSplits - 1) + local ok, result = pcall(model.forward, model, input) + assert(not ok, "the missing input to split should be detected") +end tester:add(test):run() |