diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-25 23:35:36 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-25 23:35:36 +0400 |
commit | 590a77573fee782060177adfcd0afc97d3c30521 (patch) | |
tree | e06d0ad220803af975b83b2d7263457f5c3d20f1 /test | |
parent | 75568888fb18158925c7953729f29f90778059ab (diff) |
reshape works (unit tested)
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index b7b9fe8..6a7f2ae 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2122,6 +2122,25 @@ function nntest.View() "Error in minibatch nElement with size -1") end +function nntest.Reshape() + local input = torch.rand(10) + local template = torch.rand(5,2) + local target = template:size():totable() + local module = nn.Reshape(template:size()) + mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (1)") + local module = nn.View(unpack(target)) + mytester:assertTableEq(module:forward(input):size():totable(), target, "Error in forward (2)") + + -- Minibatch + local minibatch = torch.rand(5,10) + mytester:assertTableEq(module:forward(minibatch):size(1), + minibatch:size(1), + "Error in minibatch dimension") + mytester:assertTableEq(module:forward(minibatch):nElement(), + minibatch:nElement(), + "Error in minibatch nElement") +end + -- Define a test for SpatialUpSamplingCuda function nntest.SpatialUpSamplingNearest() local scale = torch.random(2,4) |