Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorIvo Danihelka <danihelka@google.com>2015-07-23 12:27:56 +0300
committerIvo Danihelka <danihelka@google.com>2015-07-23 12:27:56 +0300
commite9454c5fae2cdae12001a38a021a1d0c1894bfd1 (patch)
treec1e78b9a46a7921c5c6d1c50d17020ff3d29c367 /test
parentd3328034c7fcb020f590249779750b6d25d75110 (diff)
Added an assert to check the number of inputs to a split.
Diffstat (limited to 'test')
-rw-r--r--test/test_nngraph.lua21
1 files changed, 21 insertions, 0 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index 6698b0d..86bf730 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -358,5 +358,26 @@ function test.test_annotateGraph()
checkDotFile(bg_tmpfile)
end
+function test.test_splitMore()
+ local nSplits = 2
+ local in1 = nn.Identity()()
+ local out1, out2 = nn.SplitTable(2)(in1):split(nSplits)
+
+ local model = nn.gModule({in1}, {out1, out2})
+ local input = torch.randn(10, nSplits + 1)
+ local ok, result = pcall(model.forward, model, input)
+ assert(not ok, "the extra input to split should be detected")
+end
+
+function test.test_splitLess()
+ local nSplits = 3
+ local in1 = nn.Identity()()
+ local out1, out2, out3 = nn.SplitTable(2)(in1):split(nSplits)
+
+ local model = nn.gModule({in1}, {out1, out2, out3})
+ local input = torch.randn(10, nSplits - 1)
+ local ok, result = pcall(model.forward, model, input)
+ assert(not ok, "the missing input to split should be detected")
+end
tester:add(test):run()