diff options
Diffstat (limited to 'Module.lua')
-rw-r--r-- | Module.lua | 13 |
1 files changed, 1 insertions, 12 deletions
@@ -113,17 +113,6 @@ function Module:clone(...) return clone end -local function recursiveType(param, type_str) - if torch.type(param) == 'table' then - for i = 1, #param do - param[i] = recursiveType(param[i], type_str) - end - elseif torch.isTensor(param) then - param = param:type(type_str) - end - return param -end - function Module:type(type) assert(type, 'Module: must provide a type to convert to') -- find all tensors and convert them @@ -132,7 +121,7 @@ function Module:type(type) -- are table's of tensors. To be general we need to recursively -- cast fields that may be nested tables. if key ~= 'modules' then - self[key] = recursiveType(self[key], type) + self[key] = nn._utils.recursiveType(param, type) end end -- find submodules in classic containers 'modules' |