github.com/clementfarabet/lua---nnx.git
author    Marco Scoffier <github@metm.org>    2011-10-17 18:23:35 +0400
committer Marco Scoffier <github@metm.org>    2011-10-17 18:23:35 +0400
commit    4d51e8504f3e6b7250b572559ceb840a4f7845fa (patch)
tree      fdc6334635499c4fb57a394e75f3bc18b4b54e7e
parent    38eeb494818ddeffc1a4322e080bb08680b8c6c5 (diff)
parent    062a4c9c0170076a116b6784339046593f2493b1 (diff)
Merge branch 'master' of github.com:clementfarabet/lua---nnx
-rw-r--r--  BatchOptimization.lua |  1
-rw-r--r--  DiagHessian.lua       |  2
-rw-r--r--  SGDOptimization.lua   | 24
3 files changed, 21 insertions(+), 6 deletions(-)
diff --git a/BatchOptimization.lua b/BatchOptimization.lua
index 5cb5230..1e83bbb 100644
--- a/BatchOptimization.lua
+++ b/BatchOptimization.lua
@@ -99,7 +99,6 @@ function Batch:forward_sequential(inputs, targets, options)
          self.output = self.output/batchsize
       else -- minibatch is assumed to be a BatchSize x ... tensor
-
          -- estimate f
          local output = self.module:forward(inputs)
          self.output = self.criterion:forward(output, targets)
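The minibatch branch shown above computes the objective in one vectorized pass instead of looping over samples. A minimal sketch of the same call pattern, assuming a Torch7 session with the nn package; the Linear/MSECriterion pair is illustrative only, not the module this repository trains:

   require 'nn'

   -- hypothetical module/criterion pair, just to exercise the call pattern
   local module    = nn.Linear(10, 1)
   local criterion = nn.MSECriterion()

   local inputs  = torch.randn(16, 10)   -- a BatchSize x ... tensor, as assumed above
   local targets = torch.randn(16, 1)

   -- estimate f, as in the minibatch branch of Batch:forward_sequential
   local output = module:forward(inputs)
   local f = criterion:forward(output, targets)
   print('f on the minibatch: '..f)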
diff --git a/DiagHessian.lua b/DiagHessian.lua
index 40b5aa9..dfcdcaf 100644
--- a/DiagHessian.lua
+++ b/DiagHessian.lua
@@ -17,7 +17,7 @@ function nn.Criterion.backwardDiagHessian(self, input, diagHessianOutput)
    return self.diagHessianInput
 end
 
--- MSECriterion
+ -- MSECriterion
 function nn.MSECriterion.backwardDiagHessian(self, input, diagHessianOutput)
    self.diagHessianInput = self.diagHessianInput or input.new()
    self.diagHessianInput:resizeAs(input):fill(1)
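MSECriterion.backwardDiagHessian can fill the diagonal Hessian of its input with a constant because the second derivative of the half squared error is 1 for every component: with e(x) = 0.5*(x - t)^2, de/dx = x - t and d2e/dx2 = 1. A quick numerical check of that constant in plain Lua, assuming the un-averaged half-squared-error convention that fill(1) implies:

   -- central second difference of e(x) = 0.5*(x - t)^2 at a few points
   local t, h = 0.3, 1e-4
   local function e(x) return 0.5 * (x - t)^2 end
   for _, x in ipairs({-1, 0, 2}) do
      local d2 = (e(x + h) - 2*e(x) + e(x - h)) / (h * h)
      print(string.format('x = %+.1f   d2e/dx2 ~ %.4f', x, d2))  -- ~1 everywhere
   end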
diff --git a/SGDOptimization.lua b/SGDOptimization.lua
index 419bdc7..e26c6ed 100644
--- a/SGDOptimization.lua
+++ b/SGDOptimization.lua
@@ -47,7 +47,7 @@ function SGD:optimize()
    if self.learningRates then
       -- we are using diagHessian and have individual learningRates
       self.deltaParameters = self.deltaParameters or
-         self.parameters.new():resizeAs(self.currentGradParameters)
+         torch.Tensor():typeAs(self.parameters):resizeAs(self.currentGradParameters)
       self.deltaParameters:copy(self.learningRates):cmul(self.currentGradParameters)
       self.parameters:add(-learningRate, self.deltaParameters)
    else
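With individual rates, the update above is theta <- theta - eta * (r .* grad), where r are the per-parameter learningRates estimated below and eta is the global learningRate; the typeAs() allocation keeps deltaParameters in the same tensor type (Float, Double, Cuda) as the flattened parameters. A minimal sketch of the same two tensor operations on toy data, assuming only the torch package:

   require 'torch'

   local parameters    = torch.randn(5)
   local gradParams    = torch.randn(5)
   local learningRates = torch.rand(5)   -- per-parameter rates, e.g. 1/diagHessian
   local learningRate  = 0.01            -- global rate

   -- delta = learningRates .* grad, allocated with the parameters' tensor type
   local delta = torch.Tensor():typeAs(parameters):resizeAs(gradParams)
   delta:copy(learningRates):cmul(gradParams)
   parameters:add(-learningRate, delta)  -- theta <- theta - eta * delta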
@@ -80,12 +80,17 @@ function SGD:diagHessian(inputs, targets)
    if not self.learningRates then
       -- do initialization
       self.diagHessianEpsilon = self.diagHessianEpsilon or 1e-3
-      self.learningRates = self.parameters.new():resizeAs(self.parameters):fill(1)
+      self.learningRates = torch.Tensor():typeAs(self.parameters):resizeAs(self.parameters):fill(1)
       self.module:initDiagHessianParameters()
       self.diagHessianParameters =
          nnx.flattenParameters(nnx.getDiagHessianParameters(self.module))
    end
+   -- reset gradients
+   self.gradParameters:zero()
+   -- reset Hessian parameters
    self.diagHessianParameters:zero()
+   -- reset individual learningRates
+   self.learningRates:fill(1)
    -- estimate diag hessian over dataset
    if type(inputs) == 'table' then -- slow
       for i = 1,#inputs do
@@ -106,11 +111,22 @@ function SGD:diagHessian(inputs, targets)
    end
    -- protect diag hessian (clamping at epsilon is the proper way of doing it;
    -- the commented add() below was the faster shortcut)
-   -- self.diagHessianParameters:apply(function(x) return math.max(x, diagHessianEpsilon) end)
-   self.diagHessianParameters:add(self.diagHessianEpsilon)
+   self.diagHessianParameters:apply(
+      function(x)
+         return math.max(x, self.diagHessianEpsilon)
+      end)
+   -- self.diagHessianParameters:add(self.diagHessianEpsilon)
    -- now learning rates are obtained like this:
    self.learningRates:cdiv(self.diagHessianParameters)
+   print('<diagHessian>')
+   print(' + norm of dhP: '..self.diagHessianParameters:norm()..
+         ' norm of LR: '..self.learningRates:norm())
+   print(' + max dhP: '..self.diagHessianParameters:max()..
+         ' max LR: '..self.learningRates:max())
+   print(' + min dhP: '..self.diagHessianParameters:min()..
+         ' min LR: '..self.learningRates:min())
+   -- self.learningRates:div(self.learningRates:norm())
end
function SGD:optimalLearningRate(inputs, targets)
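Taken together, diagHessian builds the per-parameter rates as r = 1 ./ max(dhP, eps): learningRates is re-filled with ones (cdiv divides in place, so it must start from 1), the diagonal Hessian estimate is clamped at epsilon, and cdiv does the elementwise division. Clamping is safer than the old add(eps): adding epsilon leaves near-zero or negative curvature estimates tiny or negative, producing huge or wrong-signed rates, while the clamp bounds every rate by 1/eps. A toy illustration, assuming eps = 1e-3 as in the initialization above:

   require 'torch'

   local eps = 1e-3
   local dhP = torch.Tensor({4.0, 0.5, 1e-6, -0.2})  -- toy diag Hessian estimates

   -- r = 1 ./ max(dhP, eps), mirroring fill(1) followed by cdiv
   local clamped = dhP:clone():apply(function(x) return math.max(x, eps) end)
   local rates = torch.Tensor(dhP:size()):fill(1):cdiv(clamped)
   print(rates)  -- 0.25, 2, 1000, 1000: pathological curvature is capped at 1/eps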