Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-07-27 13:26:30 +0300
committerSoumith Chintala <soumith@gmail.com>2015-07-27 13:26:30 +0300
commit1115a7d492875b87c88aa9673bd73230b1b4598f (patch)
tree7df3b2188d05fa6d612cd24655c8dcba7d60f80f
parent328ce70b76581a4aa0706ede2ca9ee13348c6507 (diff)
parentace001045b2fdf7a8c2ebe69b7231e6d1c88ee65 (diff)
Merge pull request #331 from dominikgrewe/singleton
Make addSingletonDimension() more robust and add tests.
-rw-r--r--test.lua26
-rw-r--r--utils.lua4
2 files changed, 30 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index e198227..f5e5295 100644
--- a/test.lua
+++ b/test.lua
@@ -3993,6 +3993,32 @@ function nntest.Padding()
mytester:assertTensorEq(gradInput, input, 0.00001, "Padding backward error")
end
+function nntest.addSingletonDimension()
+ local dims = torch.random(5)
+ local size = torch.LongTensor(dims):random(10)
+ local perm = torch.randperm(dims):totable()
+ local tensor = torch.Tensor(unpack(size:totable())):uniform():permute(unpack(perm))
+ size = torch.gather(size, 1, torch.LongTensor(perm))
+
+ local firstDim = nn.utils.addSingletonDimension(tensor)
+ mytester:assertTableEq(firstDim:size():totable(), {1, unpack(size:totable())},
+ "wrong size for singleton dimension 1")
+ mytester:assertTensorEq(firstDim[1], tensor, 0,
+ "wrong content for singleton dimension 1")
+
+ local dim = torch.random(dims)
+ local result = nn.utils.addSingletonDimension(tensor, dim)
+ local resultSize = size:totable()
+ table.insert(resultSize, dim, 1)
+ mytester:assertTableEq(result:size():totable(), resultSize,
+ "wrong size for random singleton dimension")
+ mytester:assertTensorEq(result:select(dim, 1), tensor, 0,
+ "wrong content for random singleton dimension")
+
+ mytester:assertError(function() nn.utils.addSingletonDimension(tensor, dims + 1) end,
+ "invalid dimension not detected")
+end
+
mytester:add(nntest)
if not nn then
diff --git a/utils.lua b/utils.lua
index 74ff7e6..4d89568 100644
--- a/utils.lua
+++ b/utils.lua
@@ -66,6 +66,10 @@ function nn.utils.recursiveAdd(t1, val, t2)
end
function nn.utils.addSingletonDimension(t, dim)
+ assert(torch.isTensor(t), "input tensor expected")
+ local dim = dim or 1
+ assert(dim > 0 and dim <= t:dim(), "invalid dimension: " .. dim)
+
local view = t.new()
local size = torch.LongStorage(t:dim() + 1)
local stride = torch.LongStorage(t:dim() + 1)