diff options
Diffstat (limited to 'rmsprop.lua')
-rw-r--r-- | rmsprop.lua | 50 |
1 file changed, 25 insertions, 25 deletions
diff --git a/rmsprop.lua b/rmsprop.lua index 038af21..1eb526d 100644 --- a/rmsprop.lua +++ b/rmsprop.lua @@ -22,36 +22,36 @@ RETURN: ]] function optim.rmsprop(opfunc, x, config, state) - -- (0) get/update state - local config = config or {} - local state = state or config - local lr = config.learningRate or 1e-2 - local alpha = config.alpha or 0.99 - local epsilon = config.epsilon or 1e-8 - local wd = config.weightDecay or 0 - - -- (1) evaluate f(x) and df/dx - local fx, dfdx = opfunc(x) - - -- (2) weight decay - if wd ~= 0 then + -- (0) get/update state + local config = config or {} + local state = state or config + local lr = config.learningRate or 1e-2 + local alpha = config.alpha or 0.99 + local epsilon = config.epsilon or 1e-8 + local wd = config.weightDecay or 0 + + -- (1) evaluate f(x) and df/dx + local fx, dfdx = opfunc(x) + + -- (2) weight decay + if wd ~= 0 then dfdx:add(wd, x) - end + end - -- (3) initialize mean square values and square gradient storage - if not state.m then + -- (3) initialize mean square values and square gradient storage + if not state.m then state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):fill(1) state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx) - end + end - -- (4) calculate new (leaky) mean squared values - state.m:mul(alpha) - state.m:addcmul(1.0-alpha, dfdx, dfdx) + -- (4) calculate new (leaky) mean squared values + state.m:mul(alpha) + state.m:addcmul(1.0-alpha, dfdx, dfdx) - -- (5) perform update - state.tmp:sqrt(state.m):add(epsilon) - x:addcdiv(-lr, dfdx, state.tmp) + -- (5) perform update + state.tmp:sqrt(state.m):add(epsilon) + x:addcdiv(-lr, dfdx, state.tmp) - -- return x*, f(x) before optimization - return x, {fx} + -- return x*, f(x) before optimization + return x, {fx} end |