diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index be17fd7..5a127db 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1790,6 +1790,25 @@ function nntest.LookupTable() mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end +function nntest.View() + local input = torch.rand(10) + local template = torch.rand(5,2) + local target = template:size():totable() + local module = nn.View(template:size()) + mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (1)") + local module = nn.View(unpack(target)) + mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (2)") + + -- Minibatch + local minibatch = torch.rand(5,10) + mytester:assertTableEq(module:forward(minibatch):size(1), + minibatch:size(1), + "Error in minibatch dimension") + mytester:assertTableEq(module:forward(minibatch):nElement(), + minibatch:nElement(), + "Error in minibatch nElement") +end + mytester:add(nntest) if not nn then |