diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-10-09 20:36:55 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-10-09 20:36:55 +0400 |
commit | d59a0af48b94af44a66ba68a9fa0676ee7ca94da (patch) | |
tree | f3e4da27170e26e4ba8fe08b26f74c0461ce9fff /init.lua | |
parent | 7e7c5bdc12ce6f261f20e9701a7ae19741fe3ba6 (diff) |
Added diag hessian support for a few modules.
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 13 |
1 file changed, 13 insertions, 0 deletions
@@ -108,6 +108,7 @@ torch.include('nnx', 'SGDOptimization.lua') torch.include('nnx', 'LBFGSOptimization.lua') torch.include('nnx', 'CGOptimization.lua') torch.include('nnx', 'GeneticSGDOptimization.lua') +torch.include('nnx', 'DiagHessian.lua') -- trainers: torch.include('nnx', 'Trainer.lua') @@ -196,6 +197,18 @@ function nnx.getGradParameters(...) return holder end +function nnx.getDiagHessianParameters(...) + -- to hold all parameters found + local holder = {} + -- call recursive call + local modules = {...} + for _,module in ipairs(modules) do + get(module, holder, {'diagHessianWeight', 'diagHessianBias'}) + end + -- return all parameters found + return holder +end + function nnx.flattenParameters(parameters) -- already flat ? local flat = true |