github.com/torch/optim.git

author     Clement Farabet <clement.farabet@gmail.com>  2012-01-18 10:36:21 +0400
committer  Clement Farabet <clement.farabet@gmail.com>  2012-01-18 10:36:21 +0400
commit     abbbab1d26c0b639acac108cd7d8e9219a5292ee (patch)
tree       8b1d9bdc21d208820642db95e84aa5d952001c27 /asgd.lua
parent     2aa6871430f3b9c9a1d6f8aba9d796d7a5e88c3c (diff)
Fixed bugs in SGD and ASGD implementations.

In asgd.lua: use the gradient returned by opfunc instead of the
nonexistent state.dfdx, initialize the running average from state.ax
rather than the mistyped state.a, drop the undefined eta_0 and
evalCounters references, and move all hyper-parameters into the state
table, removing the per-parameter learningRates path.

Diffstat (limited to 'asgd.lua')
-rw-r--r--  asgd.lua  |  37
1 file changed, 13 insertions(+), 24 deletions(-)
diff --git a/asgd.lua b/asgd.lua
index b1de638..11b3f92 100644
--- a/asgd.lua
+++ b/asgd.lua
@@ -5,7 +5,7 @@
-- x := (1 - lambda eta_t) x - eta_t df/dx(z,x)
-- a := a + mu_t [ x - a ]
--
--- eta_t = eta_0 / (1 + lambda eta0 t) ^ 0.75
+-- eta_t = eta0 / (1 + lambda eta0 t) ^ 0.75
-- mu_t = 1/max(1,t-t0)
--
-- implements the ASGD algorithm as in L. Bottou's sgd-2.0
@@ -16,11 +16,10 @@
-- x : the initial point
-- state : a table describing the state of the optimizer; after each
-- call the state is modified
--- state.eta0/learningRate : learning rate
+-- state.eta0 : learning rate
-- state.lambda : decay term
-- state.alpha : power for eta update
-- state.t0 : point at which to start averaging
--- state.learningRates : vector of individual learning rates
--
-- RETURN:
-- x : the new x vector
@@ -30,37 +29,27 @@
function optim.asgd(opfunc, x, state)
-- (0) get/update state
local state = state or {}
- local eta_t = state.eta_t or eta_0
- local lambda = state.lambda or 1
- local alpha = state.alpha or 0.75
- local t0 = state.t0 or 1e6
- local lrs = state.learningRates
- state.evalCounter = state.evalCounter or 0
- local nevals = state.evalCounters
+ state.eta0 = state.eta0 or 1e-4
+ state.lambda = state.lambda or 1e-4
+ state.alpha = state.alpha or 0.75
+ state.t0 = state.t0 or 1e6
- state.eta_0 = state.eta0 or state.learningRate or 1e-4
- state.mu_0 = state.mu_t or 0
+ -- (hidden state)
+ state.eta_t = state.eta_t or state.eta0
+ state.mu_t = state.mu_t or 1
state.t = state.t or 0
-- (1) evaluate f(x) and df/dx
local fx,dfdx = opfunc(x)
-- (2) decay term
- x:mul(1 - lambda*state.eta_t)
+ x:mul(1 - state.lambda*state.eta_t)
-- (3) update x
- if lrs then
- if not state.deltax then
- state.deltax = torch.Tensor():typeAs(x):resizeAs(dfdx)
- end
- state.deltax:copy(lrs):cmul(dfdx)
- x:add(-state.eta_t, state.deltax)
- else
- x:add(-state.eta_t, state.dfdx)
- end
+ x:add(-state.eta_t, dfdx)
-- (4) averaging
- state.ax = state.a or torch.Tensor():typeAs(x):resizeAs(x):zero()
+ state.ax = state.ax or torch.Tensor():typeAs(x):resizeAs(x):zero()
state.tmp = state.tmp or torch.Tensor():typeAs(state.ax):resizeAs(state.ax)
if state.mu_t ~= 1 then
state.tmp:copy(x)
@@ -72,7 +61,7 @@ function optim.asgd(opfunc, x, state)
-- (5) update eta_t and mu_t
state.t = state.t + 1
- state.eta_t = state.eta0 / math.pow((1 + state.lambda * state.eta0 * state.t), alpha)
+ state.eta_t = state.eta0 / math.pow((1 + state.lambda * state.eta0 * state.t), state.alpha)
state.mu_t = 1 / math.max(1, state.t - state.t0)
-- return f(x_old), x_new, and averaged x
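
For reference, a standalone sketch (not part of the patch) of the two
schedules the header comments document, using the defaults this patch
sets (eta0 = lambda = 1e-4, alpha = 0.75, t0 = 1e6); the sample values
of t are illustrative only:

-- Sketch of the learning-rate and averaging schedules. eta_t decays
-- polynomially with t; mu_t stays at 1 until t passes t0, after which
-- the averaging weight shrinks as 1/(t - t0).
local eta0, lambda, alpha, t0 = 1e-4, 1e-4, 0.75, 1e6
for _, t in ipairs({0, 1, 1e3, 1e6, 2e6}) do
   local eta_t = eta0 / math.pow(1 + lambda*eta0*t, alpha)
   local mu_t  = 1 / math.max(1, t - t0)
   print(string.format("t=%g  eta_t=%g  mu_t=%g", t, eta_t, mu_t))
end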
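
And a minimal usage sketch, assuming torch and this optim package are
installed; the quadratic objective, tensor size, and hyper-parameter
values below are illustrative, not taken from the commit:

require 'torch'
require 'optim'

-- f(x) = 0.5 * ||x||^2, so df/dx = x
local function opfunc(x)
   return 0.5 * x:dot(x), x:clone()
end

local x = torch.randn(10)
local state = {eta0 = 1e-2, lambda = 1e-4, alpha = 0.75, t0 = 100}

for _ = 1, 200 do
   -- the RETURN doc above lists the new x as the first return value
   x = optim.asgd(opfunc, x, state)
end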