Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-08-23 04:02:04 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-08-23 04:02:04 +0400
commit4ceab27dcac38753d60fa4eea4ff0c09be6eda25 (patch)
tree155f4e5fefa51352a883627b90094b400e571b02
parent30308077713d4bea40f661035a429021c0a278ff (diff)
Added a getParameters helper, to retrieve all trainable params of an nn.
-rw-r--r--init.lua36
1 files changed, 36 insertions, 0 deletions
diff --git a/init.lua b/init.lua
index bbe5662..2a85955 100644
--- a/init.lua
+++ b/init.lua
@@ -119,6 +119,13 @@ function nnx.empty(module)
local type = torch.typename(entry)
if type and type:find('^nn.') then
nnx.empty(entry)
+ elseif type(entry) == 'table' then
+ for i,entry in ipairs(entry) do
+ local type = torch.typename(entry)
+ if type and type:find('^nn.') then
+ nnx.empty(entry)
+ end
+ end
end
end
end
@@ -132,3 +139,32 @@ function nnx.empty(module)
module.gradInput:storage():resize(0)
end
end
+
+local function getParameters(module, holder)
+ -- find submodules in classic containers 'modules'
+ if module.modules then
+ for _,module in ipairs(module.modules) do
+ getParameters(module, holder)
+ end
+ else
+ -- store weight and bias parameters
+ if module.weight then
+ table.insert(holder, module.weight)
+ end
+ if module.bias then
+ table.insert(holder, module.bias)
+ end
+ end
+end
+
+function nnx.getParameters(...)
+ -- to hold all parameters found
+ holder = {}
+ -- call recursive call
+ local modules = {...}
+ for _,module in ipairs(modules) do
+ getParameters(module, holder)
+ end
+ -- return all parameters found
+ return holder
+end