
github.com/torch/optim.git
author     karpathy <andrej.karpathy@gmail.com>  2015-04-16 02:37:43 +0300
committer  karpathy <andrej.karpathy@gmail.com>  2015-04-16 02:37:43 +0300
commit     f25684174a6101f3b4521e74bd0718ebd848c5cd (patch)
tree       d5be412b82008647ea3421d6b5118d2b3216d4b7
parent     e97f00706f5ac4b38909ae64e2afeaddd6409cb0 (diff)
fixing rmsprop bug
-rw-r--r--  rmsprop.lua  20
1 file changed, 7 insertions(+), 13 deletions(-)
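For context: the diff below replaces the old clamped 1/sqrt(m) step, which never multiplied by the gradient, with the standard leaky-average RMSprop update, and changes the defaults to learningRate = 1e-2 and alpha = 0.95. The following is a minimal sketch of the new per-iteration update written out on plain tensors, using the same names (m, dfdx, lr, alpha, epsilon) as the patched code; the random tensors and their sizes are illustrative assumptions, not taken from the commit.

-- Sketch of the corrected update step (same logic as the '+' lines in the diff below)
require 'torch'

local lr, alpha, epsilon = 1e-2, 0.95, 1e-8   -- new defaults introduced by this commit
local x    = torch.randn(5)                    -- parameters (illustrative)
local dfdx = torch.randn(5)                    -- gradient at x (illustrative)
local m    = torch.zeros(5)                    -- running mean of squared gradients, starts at zero

-- leaky average of squared gradients: m = alpha*m + (1-alpha)*dfdx^2
m:mul(alpha):addcmul(1.0 - alpha, dfdx, dfdx)

-- parameter step: x = x - lr * dfdx / (sqrt(m) + epsilon)
local denom = m:clone():sqrt():add(epsilon)
x:addcdiv(-lr, dfdx, denom)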
diff --git a/rmsprop.lua b/rmsprop.lua
index b6baf96..bf771e9 100644
--- a/rmsprop.lua
+++ b/rmsprop.lua
@@ -9,9 +9,6 @@ ARGS:
 - 'config.learningRate' : learning rate
 - 'config.alpha' : smoothing constant
 - 'config.epsilon' : value with which to initialise m
-- 'config.epsilon2' : stablisation to prevent mean square going to zero
-- 'config.max_gain' : stabilisation to prevent lr multiplier exploding
-- 'config.min_gain' : stabilisation to prevent lr multiplier exploding
 - 'state = {m, dfdx_sq}' : a table describing the state of the optimizer; after each
      call the state is modified

@@ -25,27 +22,24 @@ function optim.rmsprop(opfunc, x, config, state)
    -- (0) get/update state
    local config = config or {}
    local state = state or config
-   local lr = config.learningRate or 1e-4
-   local alpha = config.alpha or 0.998
+   local lr = config.learningRate or 1e-2
+   local alpha = config.alpha or 0.95
    local epsilon = config.epsilon or 1e-8
-   local epsilon2 = config.epsilon2 or 1e-8
-   local max_gain = config.max_gain or 1000
-   local min_gain = config.min_gain or 1e-8

    -- (1) evaluate f(x) and df/dx
    local fx, dfdx = opfunc(x)

    -- (2) initialize mean square values and square gradient storage
-   state.m = state.m or torch.Tensor():typeAs(dfdx):resizeAs(dfdx):fill(epsilon)
+   state.m = state.m or torch.Tensor():typeAs(dfdx):resizeAs(dfdx):zero()
    state.tmp = state.tmp or x.new(dfdx:size()):zero()

-   -- (3) calculate new mean squared values
+   -- (3) calculate new (leaky) mean squared values
    state.m:mul(alpha)
-   state.m:addcmul(1.0-alpha,dfdx,dfdx):add(epsilon2)
+   state.m:addcmul(1.0-alpha,dfdx,dfdx)

    -- (4) perform update
-   state.tmp:copy(state.m):pow(-0.5):clamp(min_gain, max_gain)
-   x:add(-lr, state.tmp)
+   state.tmp:copy(state.m):sqrt()
+   x:addcdiv(-lr, dfdx, state.tmp:add(epsilon))

    -- return x*, f(x) before optimization
    return x, {fx}, state.tmp
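A minimal usage sketch of the patched optimizer on a toy least-squares problem. The data A and b, the closure name feval, and the iteration count are illustrative assumptions; only the optim.rmsprop(opfunc, x, config, state) signature and the new default hyperparameters come from the diff above.

-- Toy example: minimize f(x) = 0.5 * ||A*x - b||^2 with the fixed optim.rmsprop
require 'torch'
require 'optim'

local A = torch.randn(10, 4)   -- hypothetical data
local b = torch.randn(10)
local x = torch.zeros(4)       -- parameters, updated in place by rmsprop

-- opfunc must return f(x) and df/dx, as documented in the header above
local function feval(x)
   local r = A * x - b          -- residual
   local fx = 0.5 * r:dot(r)    -- objective value
   local dfdx = A:t() * r       -- gradient of the quadratic
   return fx, dfdx
end

-- new defaults after this commit: learningRate = 1e-2, alpha = 0.95, epsilon = 1e-8
local config = {learningRate = 1e-2, alpha = 0.95, epsilon = 1e-8}
local state = {}

for i = 1, 200 do
   -- each call performs: m = alpha*m + (1-alpha)*dfdx^2; x = x - lr * dfdx / (sqrt(m) + epsilon)
   optim.rmsprop(feval, x, config, state)
end

print(x)   -- should approach the least-squares solution of A*x = b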