diff options
author | karpathy <andrej.karpathy@gmail.com> | 2015-04-16 02:37:43 +0300 |
---|---|---|
committer | karpathy <andrej.karpathy@gmail.com> | 2015-04-16 02:37:43 +0300 |
commit | f25684174a6101f3b4521e74bd0718ebd848c5cd (patch) | |
tree | d5be412b82008647ea3421d6b5118d2b3216d4b7 | |
parent | e97f00706f5ac4b38909ae64e2afeaddd6409cb0 (diff) |
fixing rmsprop bug
-rw-r--r-- | rmsprop.lua | 20 |
1 file changed, 7 insertions, 13 deletions
diff --git a/rmsprop.lua b/rmsprop.lua index b6baf96..bf771e9 100644 --- a/rmsprop.lua +++ b/rmsprop.lua @@ -9,9 +9,6 @@ ARGS: - 'config.learningRate' : learning rate - 'config.alpha' : smoothing constant - 'config.epsilon' : value with which to inistialise m -- 'config.epsilon2' : stablisation to prevent mean square going to zero -- 'config.max_gain' : stabilisation to prevent lr multiplier exploding -- 'config.min_gain' : stabilisation to prevent lr multiplier exploding - 'state = {m, dfdx_sq}' : a table describing the state of the optimizer; after each call the state is modified @@ -25,27 +22,24 @@ function optim.rmsprop(opfunc, x, config, state) -- (0) get/update state local config = config or {} local state = state or config - local lr = config.learningRate or 1e-4 - local alpha = config.alpha or 0.998 + local lr = config.learningRate or 1e-2 + local alpha = config.alpha or 0.95 local epsilon = config.epsilon or 1e-8 - local epsilon2 = config.epsilon2 or 1e-8 - local max_gain = config.max_gain or 1000 - local min_gain = config.min_gain or 1e-8 -- (1) evaluate f(x) and df/dx local fx, dfdx = opfunc(x) -- (2) initialize mean square values and square gradient storage - state.m = state.m or torch.Tensor():typeAs(dfdx):resizeAs(dfdx):fill(epsilon) + state.m = state.m or torch.Tensor():typeAs(dfdx):resizeAs(dfdx):zero() state.tmp = state.tmp or x.new(dfdx:size()):zero() - -- (3) calculate new mean squared values + -- (3) calculate new (leaky) mean squared values state.m:mul(alpha) - state.m:addcmul(1.0-alpha,dfdx,dfdx):add(epsilon2) + state.m:addcmul(1.0-alpha,dfdx,dfdx) -- (4) perform update - state.tmp:copy(state.m):pow(-0.5):clamp(min_gain, max_gain) - x:add(-lr, state.tmp) + state.tmp:copy(state.m):sqrt() + x:addcdiv(-lr, dfdx, state.tmp:add(epsilon)) -- return x*, f(x) before optimization return x, {fx}, state.tmp |