Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSergio Gomez <sergomezcol@gmail.com>2014-06-25 22:23:17 +0400
committerSergio Gomez <sergomezcol@gmail.com>2014-06-25 22:29:28 +0400
commitd85c2ce7ab24a699855f41b1919a74f81def47cd (patch)
tree27fcdcfcf13c63b06610943e82a8798ef78d260c /test
parentea9cc1df751ddb144c08a13aab3add1ab0ce90a1 (diff)
Add minibatch support for nn.JoinTable and nn.SplitTable
The method setNumInputDims allows forwarding both minibatch and non-minibatch tensors through the same module. If this method is not used, the behaviour of these modules is the same as before.
Diffstat (limited to 'test')
-rw-r--r--test/test.lua39
1 files changed, 39 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index be17fd7..2db6f2d 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1790,6 +1790,45 @@ function nntest.LookupTable()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
+function nntest.JoinTable()
+ local tensor = torch.rand(3,4,5)
+ local input = {tensor, tensor}
+ local module
+ for d = 1,tensor:dim() do
+ module = nn.JoinTable(d)
+ mytester:asserteq(module:forward(input):size(d), tensor:size(d)*2, "dimension " .. d)
+ end
+
+ -- Minibatch
+ local tensor = torch.rand(3,4,5)
+ local input = {tensor, tensor}
+ local module
+ for d = 1,tensor:dim()-1 do
+ module = nn.JoinTable(d)
+ module:setNumInputDims(2)
+ mytester:asserteq(module:forward(input):size(d+1), tensor:size(d+1)*2, "dimension " .. d)
+ end
+end
+
+function nntest.SplitTable()
+ local input = torch.randn(3,4,5)
+ local module
+ for d = 1,input:dim() do
+ module = nn.SplitTable(d)
+ mytester:asserteq(#module:forward(input), input:size(d), "dimension " .. d)
+ end
+
+ -- Minibatch
+ local input = torch.randn(3,4,5)
+ local module
+ for d = 1,input:dim()-1 do
+ module = nn.SplitTable(d)
+ module:setNumInputDims(2)
+ mytester:asserteq(#module:forward(input), input:size(d+1), "dimension " .. d)
+ end
+end
+
+
mytester:add(nntest)
if not nn then