diff options
author | Nicholas Leonard <nick@nikopia.org> | 2015-07-25 22:51:29 +0300 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2015-07-25 22:51:29 +0300 |
commit | 73ddbe49b8a58eca739e3f442bcb101b6aff0978 (patch) | |
tree | cb91414d9c918f3b5d6e865f882d0529f4f8e82d /utils.lua | |
parent | 3c13b0d41281990e3f1d4dd8a1ae45b8d5ca0f40 (diff) |
[Parallel,Multi]Criterion table inputs
Diffstat (limited to 'utils.lua')
-rw-r--r-- | utils.lua | 21 |
1 files changed, 21 insertions, 0 deletions
@@ -44,6 +44,27 @@ function nn.utils.recursiveFill(t2, val) return t2 end +function nn.utils.recursiveAdd(t1, val, t2) + if not t2 then + assert(val, "expecting at least two arguments") + t2 = val + val = 1 + end + val = val or 1 + if torch.type(t2) == 'table' then + t1 = (torch.type(t1) == 'table') and t1 or {t1} + for key,_ in pairs(t2) do + t1[key], t2[key] = nn.utils.recursiveAdd(t1[key], val, t2[key]) + end + elseif torch.isTensor(t2) and torch.isTensor(t2) then + t1:add(val, t2) + else + error("expecting nested tensors or tables. Got ".. + torch.type(t1).." and "..torch.type(t2).." instead") + end + return t1, t2 +end + function nn.utils.addSingletonDimension(t, dim) local view = t.new() local size = torch.LongStorage(t:dim() + 1) |