Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-10-09 20:36:55 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-10-09 20:36:55 +0400
commitd59a0af48b94af44a66ba68a9fa0676ee7ca94da (patch)
treef3e4da27170e26e4ba8fe08b26f74c0461ce9fff /init.lua
parent7e7c5bdc12ce6f261f20e9701a7ae19741fe3ba6 (diff)
Added diag hessian support for a few modules.
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua13
1 files changed, 13 insertions, 0 deletions
diff --git a/init.lua b/init.lua
index aff1468..5f3daf2 100644
--- a/init.lua
+++ b/init.lua
@@ -108,6 +108,7 @@ torch.include('nnx', 'SGDOptimization.lua')
torch.include('nnx', 'LBFGSOptimization.lua')
torch.include('nnx', 'CGOptimization.lua')
torch.include('nnx', 'GeneticSGDOptimization.lua')
+torch.include('nnx', 'DiagHessian.lua')
-- trainers:
torch.include('nnx', 'Trainer.lua')
@@ -196,6 +197,18 @@ function nnx.getGradParameters(...)
return holder
end
+function nnx.getDiagHessianParameters(...)
+ -- to hold all parameters found
+ local holder = {}
+ -- call recursive call
+ local modules = {...}
+ for _,module in ipairs(modules) do
+ get(module, holder, {'diagHessianWeight', 'diagHessianBias'})
+ end
+ -- return all parameters found
+ return holder
+end
+
function nnx.flattenParameters(parameters)
-- already flat ?
local flat = true