diff options
author | Will Williams <gwillwill@gmail.com> | 2015-03-01 04:35:30 +0300 |
---|---|---|
committer | Will Williams <willw@cantabresearch.com> | 2015-03-01 04:37:00 +0300 |
commit | 5d8da2a49082353882ec0e0025bdfe676a28f379 (patch) | |
tree | 61f2bcfa00df36dea1a114d6538e26a3678b388b | |
parent | a8fa1349caa01bc269ff56b738dcb05d5fba7951 (diff) |
removed clone and now only initialise grad squared on first call to function
-rw-r--r-- | rmsprop.lua | 12 |
1 file changed, 6 insertions, 6 deletions
diff --git a/rmsprop.lua b/rmsprop.lua index c3e3fe8..c8c5ddd 100644 --- a/rmsprop.lua +++ b/rmsprop.lua @@ -12,7 +12,7 @@ ARGS: - 'config.epsilon2' : stablisation to prevent mean square going to zero - 'config.max_gain' : stabilisation to prevent lr multiplier exploding - 'config.min_gain' : stabilisation to prevent lr multiplier exploding -- 'state = {m}' : a table describing the state of the optimizer; after each +- 'state = {m, dfdx_sq}' : a table describing the state of the optimizer; after each call the state is modified RETURN: @@ -35,14 +35,14 @@ function optim.rmsprop(opfunc, x, config, state) -- (1) evaluate f(x) and df/dx local fx, dfdx = opfunc(x) - -- (2) initialize mean square values + -- (2) initialize mean square values and square gradient storage state.m = state.m or torch.Tensor():typeAs(dfdx):resizeAs(dfdx):fill(epsilon) + state.dfdx_sq = state.dfdx_sq or torch.Tensor():typeAs(dfdx):resizeAs(dfdx) - -- (3) calculate new mean squared value - local dfdx_sq = dfdx:clone() - local dfdx_sq = dfdx_sq:cmul(dfdx_sq) + -- (3) calculate new mean squared values + torch.cmul(state.dfdx_sq, dfdx, dfdx) state.m:mul(alpha) - state.m:add(dfdx_sq:mul(1.0-alpha)) + state.m:add(state.dfdx_sq:mul(1.0-alpha)) state.m:add(epsilon2) -- (4) perform update |