Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--init.lua29
1 files changed, 20 insertions, 9 deletions
diff --git a/init.lua b/init.lua
index 2a85955..78948c5 100644
--- a/init.lua
+++ b/init.lua
@@ -140,19 +140,18 @@ function nnx.empty(module)
end
end
-local function getParameters(module, holder)
+local function get(module, holder, params)
-- find submodules in classic containers 'modules'
if module.modules then
for _,module in ipairs(module.modules) do
- getParameters(module, holder)
+ get(module, holder, params)
end
else
- -- store weight and bias parameters
- if module.weight then
- table.insert(holder, module.weight)
- end
- if module.bias then
- table.insert(holder, module.bias)
+ -- find parameters and store them
+ for _,param in ipairs(params) do
+ if module[param] then
+ table.insert(holder, module[param])
+ end
end
end
end
@@ -163,7 +162,19 @@ function nnx.getParameters(...)
-- call recursive call
local modules = {...}
for _,module in ipairs(modules) do
- getParameters(module, holder)
+ get(module, holder, {'weight', 'bias'})
+ end
+ -- return all parameters found
+ return holder
+end
+
+function nnx.getGradParameters(...)
+ -- to hold all parameters found
+ holder = {}
+ -- call recursive call
+ local modules = {...}
+ for _,module in ipairs(modules) do
+ get(module, holder, {'gradWeight', 'gradBias'})
end
-- return all parameters found
return holder