From d0da63e83825d4631a7f299766a1fa968eb5ccd3 Mon Sep 17 00:00:00 2001 From: Dominik Grewe Date: Tue, 28 Apr 2015 16:07:32 +0100 Subject: Make type() truly recursive. Recursively iterate over the whole table, converting each tensor to the given type. Removes need for many specialized type() functions. --- Module.lua | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) (limited to 'Module.lua') diff --git a/Module.lua b/Module.lua index d3f5a26..d6b16fb 100644 --- a/Module.lua +++ b/Module.lua @@ -113,17 +113,6 @@ function Module:clone(...) return clone end -local function recursiveType(param, type_str) - if torch.type(param) == 'table' then - for i = 1, #param do - param[i] = recursiveType(param[i], type_str) - end - elseif torch.isTensor(param) then - param = param:type(type_str) - end - return param -end - function Module:type(type) assert(type, 'Module: must provide a type to convert to') -- find all tensors and convert them @@ -132,7 +121,7 @@ function Module:type(type) -- are table's of tensors. To be general we need to recursively -- cast fields that may be nested tables. if key ~= 'modules' then - self[key] = recursiveType(self[key], type) + self[key] = nn._utils.recursiveType(param, type) end end -- find submodules in classic containers 'modules' -- cgit v1.2.3