| author    | Clement Farabet <clement.farabet@gmail.com> | 2012-01-18 10:36:21 +0400 |
|-----------|---------------------------------------------|---------------------------|
| committer | Clement Farabet <clement.farabet@gmail.com> | 2012-01-18 10:36:21 +0400 |
| commit    | abbbab1d26c0b639acac108cd7d8e9219a5292ee    |                           |
| tree      | 8b1d9bdc21d208820642db95e84aa5d952001c27 /asgd.lua |                    |
| parent    | 2aa6871430f3b9c9a1d6f8aba9d796d7a5e88c3c    |                           |
Fixed bugs in SGD and ASGD implementations.
Diffstat (limited to 'asgd.lua')
| -rw-r--r-- | asgd.lua | 37 |
|------------|----------|----|

1 file changed, 13 insertions(+), 24 deletions(-)
```diff
@@ -5,7 +5,7 @@
 -- x := (1 - lambda eta_t) x - eta_t df/dx(z,x)
 -- a := a + mu_t [ x - a ]
 --
--- eta_t = eta_0 / (1 + lambda eta0 t) ^ 0.75
+-- eta_t = eta0 / (1 + lambda eta0 t) ^ 0.75
 -- mu_t = 1/max(1,t-t0)
 --
 -- implements ASGD algoritm as in L.Bottou's sgd-2.0
@@ -16,11 +16,10 @@
 -- x : the initial point
 -- state : a table describing the state of the optimizer; after each
 --         call the state is modified
--- state.eta0/learningRate : learning rate
+-- state.eta0 : learning rate
 -- state.lambda : decay term
 -- state.alpha : power for eta update
 -- state.t0 : point at which to start averaging
--- state.learningRates : vector of individual learning rates
 --
 -- RETURN:
 -- x : the new x vector
@@ -30,37 +29,27 @@ function optim.asgd(opfunc, x, state)
 
    -- (0) get/update state
    local state = state or {}
-   local eta_t = state.eta_t or eta_0
-   local lambda = state.lambda or 1
-   local alpha = state.alpha or 0.75
-   local t0 = state.t0 or 1e6
-   local lrs = state.learningRates
-   state.evalCounter = state.evalCounter or 0
-   local nevals = state.evalCounters
+   state.eta0 = state.eta0 or 1e-4
+   state.lambda = state.lambda or 1e-4
+   state.alpha = state.alpha or 0.75
+   state.t0 = state.t0 or 1e6
 
-   state.eta_0 = state.eta0 or state.learningRate or 1e-4
-   state.mu_0 = state.mu_t or 0
+   -- (hidden state)
+   state.eta_t = state.eta_t or state.eta0
+   state.mu_t = state.mu_t or 1
    state.t = state.t or 0
 
    -- (1) evaluate f(x) and df/dx
    local fx,dfdx = opfunc(x)
 
    -- (2) decay term
-   x:mul(1 - lambda*state.eta_t)
+   x:mul(1 - state.lambda*state.eta_t)
 
    -- (3) update x
-   if lrs then
-      if not state.deltax then
-         state.deltax = torch.Tensor():typeAs(x):resizeAs(dfdx)
-      end
-      state.deltax:copy(lrs):cmul(dfdx)
-      x:add(-state.eta_t, state.deltax)
-   else
-      x:add(-state.eta_t, state.dfdx)
-   end
+   x:add(-state.eta_t, dfdx)
 
    -- (4) averaging
-   state.ax = state.a or torch.Tensor():typeAs(x):resizeAs(x):zero()
+   state.ax = state.ax or torch.Tensor():typeAs(x):resizeAs(x):zero()
    state.tmp = state.tmp or torch.Tensor():typeAs(state.ax):resizeAs(state.ax)
    if state.mu_t ~= 1 then
       state.tmp:copy(x)
@@ -72,7 +61,7 @@ function optim.asgd(opfunc, x, state)
 
    -- (5) update eta_t and mu_t
    state.t = state.t + 1
-   state.eta_t = state.eta0 / math.pow((1 + state.lambda * state.eta0 * state.t), alpha)
+   state.eta_t = state.eta0 / math.pow((1 + state.lambda * state.eta0 * state.t), state.alpha)
    state.mu_t = 1 / math.max(1, state.t - state.t0)
 
    -- return f(x_old), x_new, and averaged x
```
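For context, here is a minimal sketch of how the fixed `optim.asgd` might be driven from a training loop. The calling convention `(opfunc, x, state)` and the state fields `eta0`, `lambda`, `alpha`, `t0` come from the header comments in the diff above; the toy quadratic objective, the hyperparameter values, and the iteration count are illustrative assumptions, not part of the commit.

```lua
-- Illustrative usage sketch (not from the commit): run ASGD on a toy
-- quadratic f(x) = 0.5 * x^T x, whose gradient is simply x.
require 'torch'
require 'optim'

-- opfunc takes the point of evaluation and returns f(x) and df/dx
local function opfunc(x)
   local fx = 0.5 * x:dot(x)  -- f(x)
   local dfdx = x:clone()     -- df/dx
   return fx, dfdx
end

local x = torch.randn(10)

-- state persists across calls; asgd fills in eta_t, mu_t, t, ax, tmp itself
local state = {
   eta0   = 1e-2,  -- learning rate eta0 (value chosen for the example)
   lambda = 1e-4,  -- decay term
   alpha  = 0.75,  -- power for the eta_t update
   t0     = 100,   -- iteration after which averaging takes effect (mu_t < 1)
}

for i = 1, 1000 do
   optim.asgd(opfunc, x, state)
end

-- x now holds the last iterate; state.ax holds the averaged iterate
print(0.5 * x:dot(x), 0.5 * state.ax:dot(state.ax))
```

The call also returns the updated point, the objective value, and the averaged iterate (see the RETURN notes in the diff), but since `x` is updated in place and the running average is kept in `state.ax`, the sketch simply reads those directly.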