diff options
-rw-r--r-- | init.lua | 29 |
1 files changed, 20 insertions, 9 deletions
@@ -140,19 +140,18 @@ function nnx.empty(module) end end -local function getParameters(module, holder) +local function get(module, holder, params) -- find submodules in classic containers 'modules' if module.modules then for _,module in ipairs(module.modules) do - getParameters(module, holder) + get(module, holder, params) end else - -- store weight and bias parameters - if module.weight then - table.insert(holder, module.weight) - end - if module.bias then - table.insert(holder, module.bias) + -- find parameters and store them + for _,param in ipairs(params) do + if module[param] then + table.insert(holder, module[param]) + end end end end @@ -163,7 +162,19 @@ function nnx.getParameters(...) -- call recursive call local modules = {...} for _,module in ipairs(modules) do - getParameters(module, holder) + get(module, holder, {'weight', 'bias'}) + end + -- return all parameters found + return holder +end + +function nnx.getGradParameters(...) + -- to hold all parameters found + holder = {} + -- call recursive call + local modules = {...} + for _,module in ipairs(modules) do + get(module, holder, {'gradWeight', 'gradBias'}) end -- return all parameters found return holder |