diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-08-23 04:02:04 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-08-23 04:02:04 +0400 |
commit | 4ceab27dcac38753d60fa4eea4ff0c09be6eda25 (patch) | |
tree | 155f4e5fefa51352a883627b90094b400e571b02 /init.lua | |
parent | 30308077713d4bea40f661035a429021c0a278ff (diff) |
Added a getParameters helper, to retrieve all trainable params of an nn.
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 36 |
1 files changed, 36 insertions, 0 deletions
@@ -119,6 +119,13 @@ function nnx.empty(module) local type = torch.typename(entry) if type and type:find('^nn.') then nnx.empty(entry) + elseif type(entry) == 'table' then + for i,entry in ipairs(entry) do + local type = torch.typename(entry) + if type and type:find('^nn.') then + nnx.empty(entry) + end + end end end end @@ -132,3 +139,32 @@ function nnx.empty(module) module.gradInput:storage():resize(0) end end + +local function getParameters(module, holder) + -- find submodules in classic containers 'modules' + if module.modules then + for _,module in ipairs(module.modules) do + getParameters(module, holder) + end + else + -- store weight and bias parameters + if module.weight then + table.insert(holder, module.weight) + end + if module.bias then + table.insert(holder, module.bias) + end + end +end + +function nnx.getParameters(...) + -- to hold all parameters found + holder = {} + -- call recursive call + local modules = {...} + for _,module in ipairs(modules) do + getParameters(module, holder) + end + -- return all parameters found + return holder +end |