Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2013-08-19 00:19:52 +0400
committerClement Farabet <clement.farabet@gmail.com>2013-08-19 00:19:52 +0400
commitc386efdc1dab84854074e0874f274e34a670573e (patch)
tree20ff8acb2458f77b88543291efdc47ce80b0c3b5
parentdf10803cd6a5528d83a54af3f1a7dcf65fcef7fd (diff)
parent26ddb8a931421fd7e0feb7bc9e5406720d00cca4 (diff)
Merge pull request #2 from fidlej/topic_nesterov
Allowed to enable Nesterov momentum
-rw-r--r--sgd.lua10
1 files changed, 9 insertions, 1 deletions
diff --git a/sgd.lua b/sgd.lua
index 1f14f96..b588635 100644
--- a/sgd.lua
+++ b/sgd.lua
@@ -11,6 +11,8 @@
-- state.learningRateDecay : learning rate decay
-- state.weightDecay : weight decay
-- state.momentum : momentum
+-- state.dampening : dampening for momentum
+-- state.nesterov : enables Nesterov momentum
-- state.learningRates : vector of individual learning rates
--
-- RETURN:
@@ -28,9 +30,11 @@ function optim.sgd(opfunc, x, config, state)
local wd = config.weightDecay or 0
local mom = config.momentum or 0
local damp = config.dampening or mom
+ local nesterov = config.nesterov or false
local lrs = config.learningRates
state.evalCounter = state.evalCounter or 0
local nevals = state.evalCounter
+ assert(not nesterov or (mom > 0 and damp == 0), "Nesterov momentum requires a momentum and zero dampening")
-- (1) evaluate f(x) and df/dx
local fx,dfdx = opfunc(x)
@@ -47,7 +51,11 @@ function optim.sgd(opfunc, x, config, state)
else
state.dfdx:mul(mom):add(1-damp, dfdx)
end
- dfdx = state.dfdx
+ if nesterov then
+ dfdx:add(mom, state.dfdx)
+ else
+ dfdx = state.dfdx
+ end
end
-- (4) learning rate decay (annealing)