Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarco Scoffier <github@metm.org>2011-10-16 10:20:05 +0400
committerMarco Scoffier <github@metm.org>2011-10-16 10:20:05 +0400
commit062a4c9c0170076a116b6784339046593f2493b1 (patch)
tree912d48d80a4f74c80c90883974d24f26e2285c5f
parentd152adae01e2b11529fa7201cf66e13e537ff7c9 (diff)
reset gradParameteres when computing Hessian
-rw-r--r--BatchOptimization.lua1
-rw-r--r--DiagHessian.lua2
-rw-r--r--SGDOptimization.lua4
3 files changed, 5 insertions, 2 deletions
diff --git a/BatchOptimization.lua b/BatchOptimization.lua
index 5cb5230..1e83bbb 100644
--- a/BatchOptimization.lua
+++ b/BatchOptimization.lua
@@ -99,7 +99,6 @@ function Batch:forward_sequential(inputs, targets, options)
self.output = self.output/batchsize
else -- minibatch is assumed to be a BatchSize x ... tensor
-
-- estimate f
local output = self.module:forward(inputs)
self.output = self.criterion:forward(output, targets)
diff --git a/DiagHessian.lua b/DiagHessian.lua
index 40b5aa9..dfcdcaf 100644
--- a/DiagHessian.lua
+++ b/DiagHessian.lua
@@ -17,7 +17,7 @@ function nn.Criterion.backwardDiagHessian(self, input, diagHessianOutput)
return self.diagHessianInput
end
--- MSECriterion
+ -- MSECriterion
function nn.MSECriterion.backwardDiagHessian(self, input, diagHessianOutput)
self.diagHessianInput = self.diagHessianInput or input.new()
self.diagHessianInput:resizeAs(input):fill(1)
diff --git a/SGDOptimization.lua b/SGDOptimization.lua
index 3837950..e26c6ed 100644
--- a/SGDOptimization.lua
+++ b/SGDOptimization.lua
@@ -85,7 +85,11 @@ function SGD:diagHessian(inputs, targets)
self.diagHessianParameters =
nnx.flattenParameters(nnx.getDiagHessianParameters(self.module))
end
+ -- reset gradients
+ self.gradParameters:zero()
+ -- reset Hessian Parameterns
self.diagHessianParameters:zero()
+ -- reset individual learningRates
self.learningRates:fill(1)
-- estimate diag hessian over dataset
if type(inputs) == 'table' then -- slow