Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2015-07-25 22:51:29 +0300
committerNicholas Leonard <nick@nikopia.org>2015-07-25 22:51:29 +0300
commit73ddbe49b8a58eca739e3f442bcb101b6aff0978 (patch)
treecb91414d9c918f3b5d6e865f882d0529f4f8e82d /utils.lua
parent3c13b0d41281990e3f1d4dd8a1ae45b8d5ca0f40 (diff)
[Parallel,Multi]Criterion table inputs
Diffstat (limited to 'utils.lua')
-rw-r--r--utils.lua21
1 files changed, 21 insertions, 0 deletions
diff --git a/utils.lua b/utils.lua
index 78a22a2..74ff7e6 100644
--- a/utils.lua
+++ b/utils.lua
@@ -44,6 +44,27 @@ function nn.utils.recursiveFill(t2, val)
return t2
end
+function nn.utils.recursiveAdd(t1, val, t2)
+ if not t2 then
+ assert(val, "expecting at least two arguments")
+ t2 = val
+ val = 1
+ end
+ val = val or 1
+ if torch.type(t2) == 'table' then
+ t1 = (torch.type(t1) == 'table') and t1 or {t1}
+ for key,_ in pairs(t2) do
+ t1[key], t2[key] = nn.utils.recursiveAdd(t1[key], val, t2[key])
+ end
+ elseif torch.isTensor(t2) and torch.isTensor(t2) then
+ t1:add(val, t2)
+ else
+ error("expecting nested tensors or tables. Got "..
+ torch.type(t1).." and "..torch.type(t2).." instead")
+ end
+ return t1, t2
+end
+
function nn.utils.addSingletonDimension(t, dim)
local view = t.new()
local size = torch.LongStorage(t:dim() + 1)