diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-06-26 17:19:53 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-06-26 17:19:53 +0400 |
commit | 896ad1c1bf5588b2944c79fb24a0aee1ae7db726 (patch) | |
tree | 2d23facf5bcc965dda0b6b02c31d5e7f0d7d6bfe /test | |
parent | 4a368c6616da3698fd419fc5c6a2adffcffe5ae0 (diff) | |
parent | 8fd02c336bc2d03aa1fe21fced69615fd1e7b99b (diff) |
Merge pull request #17 from sergomezcol/master
Add minibatch support for nn.JoinTable and nn.SplitTable
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 775dded..3c4ec12 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1830,6 +1830,43 @@ function nntest.MulConstant() mytester:assertlt(err, precision, 'bprop error ') end +function nntest.JoinTable() + local tensor = torch.rand(3,4,5) + local input = {tensor, tensor} + local module + for d = 1,tensor:dim() do + module = nn.JoinTable(d) + mytester:asserteq(module:forward(input):size(d), tensor:size(d)*2, "dimension " .. d) + end + + -- Minibatch + local tensor = torch.rand(3,4,5) + local input = {tensor, tensor} + local module + for d = 1,tensor:dim()-1 do + module = nn.JoinTable(d, 2) + mytester:asserteq(module:forward(input):size(d+1), tensor:size(d+1)*2, "dimension " .. d) + end +end + +function nntest.SplitTable() + local input = torch.randn(3,4,5) + local module + for d = 1,input:dim() do + module = nn.SplitTable(d) + mytester:asserteq(#module:forward(input), input:size(d), "dimension " .. d) + end + + -- Minibatch + local input = torch.randn(3,4,5) + local module + for d = 1,input:dim()-1 do + module = nn.SplitTable(d, 2) + mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d) + end +end + + mytester:add(nntest) if not nn then |