diff options
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 36 |
1 files changed, 36 insertions, 0 deletions
@@ -119,6 +119,13 @@ function nnx.empty(module) local type = torch.typename(entry) if type and type:find('^nn.') then nnx.empty(entry) + elseif type(entry) == 'table' then + for i,entry in ipairs(entry) do + local type = torch.typename(entry) + if type and type:find('^nn.') then + nnx.empty(entry) + end + end end end end @@ -132,3 +139,32 @@ function nnx.empty(module) module.gradInput:storage():resize(0) end end + +local function getParameters(module, holder) + -- find submodules in classic containers 'modules' + if module.modules then + for _,module in ipairs(module.modules) do + getParameters(module, holder) + end + else + -- store weight and bias parameters + if module.weight then + table.insert(holder, module.weight) + end + if module.bias then + table.insert(holder, module.bias) + end + end +end + +function nnx.getParameters(...) + -- to hold all parameters found + holder = {} + -- call recursive call + local modules = {...} + for _,module in ipairs(modules) do + getParameters(module, holder) + end + -- return all parameters found + return holder +end |