author    | Soumith Chintala <soumith@gmail.com> | 2016-04-12 21:27:12 +0300
committer | Soumith Chintala <soumith@gmail.com> | 2016-04-12 21:27:12 +0300
commit    | aeeacbd628c5464db577c229597986e6f59501aa (patch)
tree      | 252427196bd48d5e5a40a3ec77f9395d5691ba41
parent    | 5e11b9fc448f3612c93429dbd125c8aa0862d3d5 (diff)
parent    | b7cd22f74d078df025807c5772d088b56c721f25 (diff)
Merge pull request #102 from colesbury/rmsprop
Add support for weight decay to rmsprop
-rw-r--r-- | rmsprop.lua | 13
1 file changed, 10 insertions, 3 deletions
diff --git a/rmsprop.lua b/rmsprop.lua
index ecb33f1..8947b18 100644
--- a/rmsprop.lua
+++ b/rmsprop.lua
@@ -9,6 +9,7 @@ ARGS:
 - 'config.learningRate' : learning rate
 - 'config.alpha' : smoothing constant
 - 'config.epsilon' : value with which to initialise m
+- 'config.weightDecay' : weight decay
 - 'state' : a table describing the state of the optimizer; after each
      call the state is modified
 - 'state.m' : leaky sum of squares of parameter gradients,
@@ -27,21 +28,27 @@ function optim.rmsprop(opfunc, x, config, state)
    local lr = config.learningRate or 1e-2
    local alpha = config.alpha or 0.99
    local epsilon = config.epsilon or 1e-8
+   local wd = config.weightDecay or 0

    -- (1) evaluate f(x) and df/dx
    local fx, dfdx = opfunc(x)

-   -- (2) initialize mean square values and square gradient storage
+   -- (2) weight decay
+   if wd ~= 0 then
+      dfdx:add(wd, x)
+   end
+
+   -- (3) initialize mean square values and square gradient storage
    if not state.m then
      state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
      state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx)
    end

-   -- (3) calculate new (leaky) mean squared values
+   -- (4) calculate new (leaky) mean squared values
    state.m:mul(alpha)
    state.m:addcmul(1.0-alpha, dfdx, dfdx)

-   -- (4) perform update
+   -- (5) perform update
    state.tmp:sqrt(state.m):add(epsilon)
    x:addcdiv(-lr, dfdx, state.tmp)
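The weight decay added here is plain L2 regularization: before the leaky mean-square statistics are updated, wd * x is folded into the gradient via dfdx:add(wd, x), which corresponds to the gradient of a penalty term (wd/2) * ||x||^2. A minimal usage sketch follows; the training-loop names (feval, params, the toy quadratic objective, and the chosen hyperparameter values) are illustrative assumptions and not part of this commit.

-- Sketch only: exercises the new config.weightDecay option under assumed names.
require 'optim'

local params = torch.randn(10)          -- hypothetical flattened parameter vector
local rmspropConfig = {
   learningRate = 1e-3,
   alpha = 0.99,
   epsilon = 1e-8,
   weightDecay = 1e-4,                  -- option introduced by this merge
}
local rmspropState = {}                 -- reused across calls; holds state.m and state.tmp

-- opfunc must return f(x) and df/dx at the current parameters
local function feval(x)
   local loss = 0.5 * x:dot(x)          -- toy quadratic objective
   local grad = x:clone()               -- its gradient: df/dx = x
   return loss, grad
end

for i = 1, 100 do
   optim.rmsprop(feval, params, rmspropConfig, rmspropState)
end

Because the decay term is added to dfdx before state.m is accumulated, the regularization also flows into the RMS normalization, matching how weight decay is handled elsewhere in optim (e.g. sgd).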