Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominik Grewe <dominikg@google.com>2015-04-28 18:07:32 +0300
committerDominik Grewe <dominikg@google.com>2015-04-28 18:36:55 +0300
commitd0da63e83825d4631a7f299766a1fa968eb5ccd3 (patch)
treeedab79c4cf664ef0dac24fb851923059027e659e /Module.lua
parent485dd619695c47e49ab56ff518edad52b74475fc (diff)
Make type() truly recursive.
Recursively iterate over the whole table, converting each tensor to the given type. Removes need for many specialized type() functions.
Diffstat (limited to 'Module.lua')
-rw-r--r--Module.lua13
1 files changed, 1 insertions, 12 deletions
diff --git a/Module.lua b/Module.lua
index d3f5a26..d6b16fb 100644
--- a/Module.lua
+++ b/Module.lua
@@ -113,17 +113,6 @@ function Module:clone(...)
return clone
end
-local function recursiveType(param, type_str)
- if torch.type(param) == 'table' then
- for i = 1, #param do
- param[i] = recursiveType(param[i], type_str)
- end
- elseif torch.isTensor(param) then
- param = param:type(type_str)
- end
- return param
-end
-
function Module:type(type)
assert(type, 'Module: must provide a type to convert to')
-- find all tensors and convert them
@@ -132,7 +121,7 @@ function Module:type(type)
-- are table's of tensors. To be general we need to recursively
-- cast fields that may be nested tables.
if key ~= 'modules' then
- self[key] = recursiveType(self[key], type)
+ self[key] = nn._utils.recursiveType(param, type)
end
end
-- find submodules in classic containers 'modules'