Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonathan Tompson <tompson@cims.nyu.edu>2014-07-22 22:45:26 +0400
committerJonathan Tompson <tompson@cims.nyu.edu>2014-07-22 22:45:26 +0400
commita76307c2a3537f5531cbbef88d649bb69b783d00 (patch)
tree660954f9a448b2b6e7cbf236c49823a1a2564d85 /Module.lua
parent419a934e84556325019155a5f23775a8b3551574 (diff)
Made module type conversion for member variables recursive (so that tables of tensors are also converted).
Diffstat (limited to 'Module.lua')
-rw-r--r--Module.lua21
1 files changed, 19 insertions, 2 deletions
diff --git a/Module.lua b/Module.lua
index 4a08d12..df22a76 100644
--- a/Module.lua
+++ b/Module.lua
@@ -113,11 +113,28 @@ function Module:clone(...)
return clone
end
+local function recursiveType(param, type_str)
+ if type(param) == 'table' then
+ for i = 1, #param do
+ param[i] = recursiveType(param[i], type_str)
+ end
+ else
+ if torch.typename(param) and
+ torch.typename(param):find('torch%..+Tensor') then
+ param = param:type(type_str)
+ end
+ end
+ return param
+end
+
function Module:type(type)
-- find all tensors and convert them
for key,param in pairs(self) do
- if torch.typename(param) and torch.typename(param):find('torch%..+Tensor') then
- self[key] = param:type(type)
+ -- Many modules (like CDivTable) have output or gradInput fields which
+ -- are table's of tensors. To be general we need to recursively
+ -- cast fields that may be nested tables.
+ if key ~= 'modules' then
+ self[key] = recursiveType(self[key], type)
end
end
-- find submodules in classic containers 'modules'