diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-08-23 04:15:32 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-08-23 04:15:32 +0400 |
commit | a7f7ad25a56fbb2787a9218748750bfcabe88e45 (patch) | |
tree | 031c6743bdfc7efb0d69dbbe61bdae461a1257e7 /init.lua | |
parent | 4ceab27dcac38753d60fa4eea4ff0c09be6eda25 (diff) |
Added getGradParameters method.
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 29 |
1 files changed, 20 insertions, 9 deletions
@@ -140,19 +140,18 @@ function nnx.empty(module) end end -local function getParameters(module, holder) +local function get(module, holder, params) -- find submodules in classic containers 'modules' if module.modules then for _,module in ipairs(module.modules) do - getParameters(module, holder) + get(module, holder, params) end else - -- store weight and bias parameters - if module.weight then - table.insert(holder, module.weight) - end - if module.bias then - table.insert(holder, module.bias) + -- find parameters and store them + for _,param in ipairs(params) do + if module[param] then + table.insert(holder, module[param]) + end end end end @@ -163,7 +162,19 @@ function nnx.getParameters(...) -- call recursive call local modules = {...} for _,module in ipairs(modules) do - getParameters(module, holder) + get(module, holder, {'weight', 'bias'}) + end + -- return all parameters found + return holder +end + +function nnx.getGradParameters(...) + -- to hold all parameters found + holder = {} + -- call recursive call + local modules = {...} + for _,module in ipairs(modules) do + get(module, holder, {'gradWeight', 'gradBias'}) end -- return all parameters found return holder |