diff options
author | Sergio Gomez <sergomezcol@gmail.com> | 2014-06-25 22:23:17 +0400 |
---|---|---|
committer | Sergio Gomez <sergomezcol@gmail.com> | 2014-06-25 22:29:28 +0400 |
commit | d85c2ce7ab24a699855f41b1919a74f81def47cd (patch) | |
tree | 27fcdcfcf13c63b06610943e82a8798ef78d260c /test | |
parent | ea9cc1df751ddb144c08a13aab3add1ab0ce90a1 (diff) |
Add minibatch support for nn.JoinTable and nn.SplitTable
The method setNumInputDims allows forwarding both minibatch and
non-minibatch tensors through the same module.
If this method is not used, the behaviour of these modules is the
same as before.
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index be17fd7..2db6f2d 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1790,6 +1790,45 @@ function nntest.LookupTable() mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end +function nntest.JoinTable() + local tensor = torch.rand(3,4,5) + local input = {tensor, tensor} + local module + for d = 1,tensor:dim() do + module = nn.JoinTable(d) + mytester:asserteq(module:forward(input):size(d), tensor:size(d)*2, "dimension " .. d) + end + + -- Minibatch + local tensor = torch.rand(3,4,5) + local input = {tensor, tensor} + local module + for d = 1,tensor:dim()-1 do + module = nn.JoinTable(d) + module:setNumInputDims(2) + mytester:asserteq(module:forward(input):size(d+1), tensor:size(d+1)*2, "dimension " .. d) + end +end + +function nntest.SplitTable() + local input = torch.randn(3,4,5) + local module + for d = 1,input:dim() do + module = nn.SplitTable(d) + mytester:asserteq(#module:forward(input), input:size(d), "dimension " .. d) + end + + -- Minibatch + local input = torch.randn(3,4,5) + local module + for d = 1,input:dim()-1 do + module = nn.SplitTable(d) + module:setNumInputDims(2) + mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d) + end +end + + mytester:add(nntest) if not nn then |