diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2013-08-19 00:19:52 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2013-08-19 00:19:52 +0400 |
commit | c386efdc1dab84854074e0874f274e34a670573e (patch) | |
tree | 20ff8acb2458f77b88543291efdc47ce80b0c3b5 | |
parent | df10803cd6a5528d83a54af3f1a7dcf65fcef7fd (diff) | |
parent | 26ddb8a931421fd7e0feb7bc9e5406720d00cca4 (diff) |
Merge pull request #2 from fidlej/topic_nesterov
Allowed to enable Nesterov momentum
-rw-r--r-- | sgd.lua | 10 |
1 file changed, 9 insertions, 1 deletion
```diff
@@ -11,6 +11,8 @@
 -- state.learningRateDecay : learning rate decay
 -- state.weightDecay : weight decay
 -- state.momentum : momentum
+-- state.dampening : dampening for momentum
+-- state.nesterov : enables Nesterov momentum
 -- state.learningRates : vector of individual learning rates
 --
 -- RETURN:
@@ -28,9 +30,11 @@ function optim.sgd(opfunc, x, config, state)
    local wd = config.weightDecay or 0
    local mom = config.momentum or 0
    local damp = config.dampening or mom
+   local nesterov = config.nesterov or false
    local lrs = config.learningRates
    state.evalCounter = state.evalCounter or 0
    local nevals = state.evalCounter
+   assert(not nesterov or (mom > 0 and damp == 0), "Nesterov momentum requires a momentum and zero dampening")
 
    -- (1) evaluate f(x) and df/dx
    local fx,dfdx = opfunc(x)
@@ -47,7 +51,11 @@ function optim.sgd(opfunc, x, config, state)
       else
          state.dfdx:mul(mom):add(1-damp, dfdx)
       end
-      dfdx = state.dfdx
+      if nesterov then
+         dfdx:add(mom, state.dfdx)
+      else
+         dfdx = state.dfdx
+      end
    end
 
    -- (4) learning rate decay (annealing)
```