diff options
author | Sergio Gomez <sergomezcol@gmail.com> | 2014-06-26 15:06:01 +0400 |
---|---|---|
committer | Sergio Gomez <sergomezcol@gmail.com> | 2014-06-26 15:06:01 +0400 |
commit | 9386e79b7eaf324c34ec1c16fbc873add39dff22 (patch) | |
tree | 9cf3ad29de027ceaf709c31dbded1dca5baf5a27 /test | |
parent | ea9cc1df751ddb144c08a13aab3add1ab0ce90a1 (diff) |
Add nn.View module
This module creates a new view of the input tensor without copying it.
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index be17fd7..5a127db 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1790,6 +1790,25 @@ function nntest.LookupTable() mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end +function nntest.View() + local input = torch.rand(10) + local template = torch.rand(5,2) + local target = template:size():totable() + local module = nn.View(template:size()) + mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (1)") + local module = nn.View(unpack(target)) + mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (2)") + + -- Minibatch + local minibatch = torch.rand(5,10) + mytester:assertTableEq(module:forward(minibatch):size(1), + minibatch:size(1), + "Error in minibatch dimension") + mytester:assertTableEq(module:forward(minibatch):nElement(), + minibatch:nElement(), + "Error in minibatch nElement") +end + mytester:add(nntest) if not nn then |