diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-06-26 18:14:13 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-06-26 18:14:13 +0400 |
commit | 1310a045ebc69a9f9e8c57d07af587a6535d5ae9 (patch) | |
tree | 46e52e5ae16e0653953424423aa81c81dd6b526f /test | |
parent | 896ad1c1bf5588b2944c79fb24a0aee1ae7db726 (diff) | |
parent | 9386e79b7eaf324c34ec1c16fbc873add39dff22 (diff) |
Merge pull request #20 from sergomezcol/view_module
Add nn.View module
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 3c4ec12..5db941a 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1867,6 +1867,25 @@ function nntest.SplitTable() end +function nntest.View() + local input = torch.rand(10) + local template = torch.rand(5,2) + local target = template:size():totable() + local module = nn.View(template:size()) + mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (1)") + local module = nn.View(unpack(target)) + mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (2)") + + -- Minibatch + local minibatch = torch.rand(5,10) + mytester:assertTableEq(module:forward(minibatch):size(1), + minibatch:size(1), + "Error in minibatch dimension") + mytester:assertTableEq(module:forward(minibatch):nElement(), + minibatch:nElement(), + "Error in minibatch nElement") +end + mytester:add(nntest) if not nn then |