author | Clement Farabet <clement.farabet@gmail.com> | 2013-05-21 22:55:05 +0400
---|---|---
committer | Clement Farabet <clement.farabet@gmail.com> | 2013-05-21 22:55:05 +0400
commit | 7f766b070ac9c7c963b718f47b3cc356740fb74a (patch) |
tree | d0a784509cece48427b981bc8260764740dea191 |
parent | 0f1143db5a827159d7f016b2ccfad17c5faf4cd6 (diff) |
CRITICAL: weight decay is finally properly implemented!
-rw-r--r-- | sgd.lua | 20
1 file changed, 10 insertions, 10 deletions
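
What changed: previously, weight decay was applied after the momentum step, directly to the parameters and scaled only by the raw learning rate (`x:add(-wd*lr, x)`), so it bypassed both momentum and learning-rate annealing. The fix adds the decay term to the gradient before momentum (`dfdx:add(wd, x)`), which matches the usual reading of weight decay as an L2 penalty on the parameters: d/dx [ f(x) + (wd/2) * ||x||^2 ] = df/dx + wd * x.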
```diff
--- a/sgd.lua
+++ b/sgd.lua
@@ -34,7 +34,12 @@ function optim.sgd(opfunc, x, config, state)
    -- (1) evaluate f(x) and df/dx
    local fx,dfdx = opfunc(x)
 
-   -- (2) apply momentum
+   -- (2) weight decay
+   if wd ~= 0 then
+      dfdx:add(wd, x)
+   end
+
+   -- (3) apply momentum
    if mom ~= 0 then
       if not state.dfdx then
          state.dfdx = torch.Tensor():typeAs(dfdx):resizeAs(dfdx):copy(dfdx)
@@ -44,15 +49,10 @@ function optim.sgd(opfunc, x, config, state)
       dfdx = state.dfdx
    end
 
-   -- (2) weight decay
-   if wd ~= 0 then
-      x:add(-wd*lr, x)
-   end
-
-   -- (3) learning rate decay (annealing)
+   -- (4) learning rate decay (annealing)
    local clr = lr / (1 + nevals*lrd)
-
-   -- (4) parameter update with single or individual learning rates
+
+   -- (5) parameter update with single or individual learning rates
    if lrs then
       if not state.deltaParameters then
          state.deltaParameters = torch.Tensor():typeAs(x):resizeAs(dfdx)
@@ -63,7 +63,7 @@ function optim.sgd(opfunc, x, config, state)
       x:add(-clr, dfdx)
    end
 
-   -- (5) update evaluation counter
+   -- (6) update evaluation counter
    state.evalCounter = state.evalCounter + 1
 
    -- return x*, f(x) before optimization
```
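
For context, here is a minimal sketch of how this code path is exercised through the `optim` API. It assumes the package's usual config keys (`learningRate`, `learningRateDecay`, `weightDecay`, `momentum`, read into `lr`, `lrd`, `wd`, and `mom` above); the quadratic objective `feval` and the values below are illustrative, not part of this commit.

```lua
-- Usage sketch: SGD with weight decay, momentum, and annealing.
require 'torch'
require 'optim'

local x = torch.randn(10)   -- parameters to optimize (updated in place)
local state = {}            -- optim.sgd keeps momentum buffers and the
                            -- evaluation counter here across calls

-- Toy objective: f(x) = 0.5 * ||x||^2, so df/dx = x.
local function feval(x)
   return 0.5 * x:dot(x), x:clone()
end

local config = {
   learningRate      = 0.1,    -- lr
   learningRateDecay = 1e-4,   -- lrd: anneals clr = lr / (1 + nevals*lrd)
   weightDecay       = 0.01,   -- wd: now added to the gradient as wd * x
   momentum          = 0.9,    -- mom
}

for i = 1, 100 do
   optim.sgd(feval, x, config, state)
end
```

With this fix, a nonzero `weightDecay` shrinks the parameters through the same momentum and annealing machinery as the data gradient, rather than through a separate, unannealed multiplicative step on `x`.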