Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominik Grewe <dominikg@google.com>2015-07-27 13:03:19 +0300
committerDominik Grewe <dominikg@google.com>2015-07-27 13:03:19 +0300
commitace001045b2fdf7a8c2ebe69b7231e6d1c88ee65 (patch)
tree7df3b2188d05fa6d612cd24655c8dcba7d60f80f /utils.lua
parent328ce70b76581a4aa0706ede2ca9ee13348c6507 (diff)
Make addSingletonDimension() more robust and add tests.
Diffstat (limited to 'utils.lua')
-rw-r--r--utils.lua4
1 files changed, 4 insertions, 0 deletions
diff --git a/utils.lua b/utils.lua
index 74ff7e6..4d89568 100644
--- a/utils.lua
+++ b/utils.lua
@@ -66,6 +66,10 @@ function nn.utils.recursiveAdd(t1, val, t2)
end
function nn.utils.addSingletonDimension(t, dim)
+ assert(torch.isTensor(t), "input tensor expected")
+ local dim = dim or 1
+ assert(dim > 0 and dim <= t:dim(), "invalid dimension: " .. dim)
+
local view = t.new()
local size = torch.LongStorage(t:dim() + 1)
local stride = torch.LongStorage(t:dim() + 1)