diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-07-27 13:26:30 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-07-27 13:26:30 +0300 |
commit | 1115a7d492875b87c88aa9673bd73230b1b4598f (patch) | |
tree | 7df3b2188d05fa6d612cd24655c8dcba7d60f80f | |
parent | 328ce70b76581a4aa0706ede2ca9ee13348c6507 (diff) | |
parent | ace001045b2fdf7a8c2ebe69b7231e6d1c88ee65 (diff) |
Merge pull request #331 from dominikgrewe/singleton
Make addSingletonDimension() more robust and add tests.
-rw-r--r-- | test.lua | 26 | ||||
-rw-r--r-- | utils.lua | 4 |
2 files changed, 30 insertions, 0 deletions
@@ -3993,6 +3993,32 @@ function nntest.Padding() mytester:assertTensorEq(gradInput, input, 0.00001, "Padding backward error") end +function nntest.addSingletonDimension() + local dims = torch.random(5) + local size = torch.LongTensor(dims):random(10) + local perm = torch.randperm(dims):totable() + local tensor = torch.Tensor(unpack(size:totable())):uniform():permute(unpack(perm)) + size = torch.gather(size, 1, torch.LongTensor(perm)) + + local firstDim = nn.utils.addSingletonDimension(tensor) + mytester:assertTableEq(firstDim:size():totable(), {1, unpack(size:totable())}, + "wrong size for singleton dimension 1") + mytester:assertTensorEq(firstDim[1], tensor, 0, + "wrong content for singleton dimension 1") + + local dim = torch.random(dims) + local result = nn.utils.addSingletonDimension(tensor, dim) + local resultSize = size:totable() + table.insert(resultSize, dim, 1) + mytester:assertTableEq(result:size():totable(), resultSize, + "wrong size for random singleton dimension") + mytester:assertTensorEq(result:select(dim, 1), tensor, 0, + "wrong content for random singleton dimension") + + mytester:assertError(function() nn.utils.addSingletonDimension(tensor, dims + 1) end, + "invalid dimension not detected") +end + mytester:add(nntest) if not nn then @@ -66,6 +66,10 @@ function nn.utils.recursiveAdd(t1, val, t2) end function nn.utils.addSingletonDimension(t, dim) + assert(torch.isTensor(t), "input tensor expected") + local dim = dim or 1 + assert(dim > 0 and dim <= t:dim(), "invalid dimension: " .. dim) + local view = t.new() local size = torch.LongStorage(t:dim() + 1) local stride = torch.LongStorage(t:dim() + 1) |