diff options
author | Jonathan Tompson <tompson@cims.nyu.edu> | 2014-07-22 22:45:26 +0400 |
---|---|---|
committer | Jonathan Tompson <tompson@cims.nyu.edu> | 2014-07-22 22:45:26 +0400 |
commit | a76307c2a3537f5531cbbef88d649bb69b783d00 (patch) | |
tree | 660954f9a448b2b6e7cbf236c49823a1a2564d85 /Module.lua | |
parent | 419a934e84556325019155a5f23775a8b3551574 (diff) |
Made module type conversion for member variables recursive (so that tables of tensors are also converted).
Diffstat (limited to 'Module.lua')
-rw-r--r-- | Module.lua | 21 |
1 files changed, 19 insertions, 2 deletions
@@ -113,11 +113,28 @@ function Module:clone(...) return clone end +local function recursiveType(param, type_str) + if type(param) == 'table' then + for i = 1, #param do + param[i] = recursiveType(param[i], type_str) + end + else + if torch.typename(param) and + torch.typename(param):find('torch%..+Tensor') then + param = param:type(type_str) + end + end + return param +end + function Module:type(type) -- find all tensors and convert them for key,param in pairs(self) do - if torch.typename(param) and torch.typename(param):find('torch%..+Tensor') then - self[key] = param:type(type) + -- Many modules (like CDivTable) have output or gradInput fields which + -- are table's of tensors. To be general we need to recursively + -- cast fields that may be nested tables. + if key ~= 'modules' then + self[key] = recursiveType(self[key], type) end end -- find submodules in classic containers 'modules' |