Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-08-23 04:15:32 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-08-23 04:15:32 +0400
commita7f7ad25a56fbb2787a9218748750bfcabe88e45 (patch)
tree031c6743bdfc7efb0d69dbbe61bdae461a1257e7 /init.lua
parent4ceab27dcac38753d60fa4eea4ff0c09be6eda25 (diff)
Added getGradParameters method.
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua29
1 files changed, 20 insertions, 9 deletions
diff --git a/init.lua b/init.lua
index 2a85955..78948c5 100644
--- a/init.lua
+++ b/init.lua
@@ -140,19 +140,18 @@ function nnx.empty(module)
end
end
-local function getParameters(module, holder)
+local function get(module, holder, params)
-- find submodules in classic containers 'modules'
if module.modules then
for _,module in ipairs(module.modules) do
- getParameters(module, holder)
+ get(module, holder, params)
end
else
- -- store weight and bias parameters
- if module.weight then
- table.insert(holder, module.weight)
- end
- if module.bias then
- table.insert(holder, module.bias)
+ -- find parameters and store them
+ for _,param in ipairs(params) do
+ if module[param] then
+ table.insert(holder, module[param])
+ end
end
end
end
@@ -163,7 +162,19 @@ function nnx.getParameters(...)
-- call recursive call
local modules = {...}
for _,module in ipairs(modules) do
- getParameters(module, holder)
+ get(module, holder, {'weight', 'bias'})
+ end
+ -- return all parameters found
+ return holder
+end
+
+function nnx.getGradParameters(...)
+ -- to hold all parameters found
+ holder = {}
+ -- call recursive call
+ local modules = {...}
+ for _,module in ipairs(modules) do
+ get(module, holder, {'gradWeight', 'gradBias'})
end
-- return all parameters found
return holder