From a7f7ad25a56fbb2787a9218748750bfcabe88e45 Mon Sep 17 00:00:00 2001 From: Clement Farabet Date: Mon, 22 Aug 2011 20:15:32 -0400 Subject: Added getGradParameters method. --- init.lua | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) (limited to 'init.lua') diff --git a/init.lua b/init.lua index 2a85955..78948c5 100644 --- a/init.lua +++ b/init.lua @@ -140,19 +140,18 @@ function nnx.empty(module) end end -local function getParameters(module, holder) +local function get(module, holder, params) -- find submodules in classic containers 'modules' if module.modules then for _,module in ipairs(module.modules) do - getParameters(module, holder) + get(module, holder, params) end else - -- store weight and bias parameters - if module.weight then - table.insert(holder, module.weight) - end - if module.bias then - table.insert(holder, module.bias) + -- find parameters and store them + for _,param in ipairs(params) do + if module[param] then + table.insert(holder, module[param]) + end end end end @@ -163,7 +162,19 @@ function nnx.getParameters(...) -- call recursive call local modules = {...} for _,module in ipairs(modules) do - getParameters(module, holder) + get(module, holder, {'weight', 'bias'}) + end + -- return all parameters found + return holder +end + +function nnx.getGradParameters(...) + -- to hold all parameters found + holder = {} + -- call recursive call + local modules = {...} + for _,module in ipairs(modules) do + get(module, holder, {'gradWeight', 'gradBias'}) end -- return all parameters found return holder -- cgit v1.2.3