Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2015-05-18 19:44:28 +0300
committernicholas-leonard <nick@nikopia.org>2015-05-22 02:02:45 +0300
commit1e7eec6f5c91c40b6dfa7509ae928bf3a54ba0d2 (patch)
tree0f9c8cc95080581ce72d6d98aafe7305a939d4a9 /utils.lua
parentb3f7bcc4120dd99a18088a784aa081f9e1688879 (diff)
MixtureTable lazy initialized buffers
Diffstat (limited to 'utils.lua')
-rw-r--r--utils.lua17
1 files changed, 17 insertions, 0 deletions
diff --git a/utils.lua b/utils.lua
index 489358c..a2bb46b 100644
--- a/utils.lua
+++ b/utils.lua
@@ -14,4 +14,21 @@ function nn.utils.recursiveType(param, type_str)
return param
end
+function nn.utils.recursiveResizeAs(t1,t2)
+ if torch.type(t2) == 'table' then
+ t1 = (torch.type(t1) == 'table') and t1 or {t1}
+ for key,_ in pairs(t2) do
+ t1[key], t2[key] = nn.utils.recursiveResizeAs(t1[key], t2[key])
+ end
+ elseif torch.isTensor(t2) then
+ t1 = torch.isTensor(t1) and t1 or t2.new()
+ t1:resizeAs(t2)
+ else
+ error("expecting nested tensors or tables. Got "..
+ torch.type(t1).." and "..torch.type(t2).." instead")
+ end
+ return t1, t2
+end
+
+
table.unpack = table.unpack or unpack