Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    Soumith Chintala <soumith@gmail.com>  2016-04-12 21:27:12 +0300
committer Soumith Chintala <soumith@gmail.com>  2016-04-12 21:27:12 +0300
commit aeeacbd628c5464db577c229597986e6f59501aa (patch)
tree   252427196bd48d5e5a40a3ec77f9395d5691ba41
parent 5e11b9fc448f3612c93429dbd125c8aa0862d3d5 (diff)
parent b7cd22f74d078df025807c5772d088b56c721f25 (diff)
Merge pull request #102 from colesbury/rmsprop
Add support for weight decay to rmsprop
-rw-r--r--  rmsprop.lua | 13
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/rmsprop.lua b/rmsprop.lua
index ecb33f1..8947b18 100644
--- a/rmsprop.lua
+++ b/rmsprop.lua
@@ -9,6 +9,7 @@ ARGS:
- 'config.learningRate' : learning rate
- 'config.alpha' : smoothing constant
- 'config.epsilon' : value with which to initialise m
+- 'config.weightDecay' : weight decay
- 'state' : a table describing the state of the optimizer;
after each call the state is modified
- 'state.m' : leaky sum of squares of parameter gradients,
@@ -27,21 +28,27 @@ function optim.rmsprop(opfunc, x, config, state)
local lr = config.learningRate or 1e-2
local alpha = config.alpha or 0.99
local epsilon = config.epsilon or 1e-8
+ local wd = config.weightDecay or 0
-- (1) evaluate f(x) and df/dx
local fx, dfdx = opfunc(x)
- -- (2) initialize mean square values and square gradient storage
+ -- (2) weight decay
+ if wd ~= 0 then
+ dfdx:add(wd, x)
+ end
+
+ -- (3) initialize mean square values and square gradient storage
if not state.m then
state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx)
end
- -- (3) calculate new (leaky) mean squared values
+ -- (4) calculate new (leaky) mean squared values
state.m:mul(alpha)
state.m:addcmul(1.0-alpha, dfdx, dfdx)
- -- (4) perform update
+ -- (5) perform update
state.tmp:sqrt(state.m):add(epsilon)
x:addcdiv(-lr, dfdx, state.tmp)