diff options
author | koray kavukcuoglu <koray@kavukcuoglu.org> | 2012-12-04 22:25:33 +0400 |
---|---|---|
committer | koray kavukcuoglu <koray@kavukcuoglu.org> | 2012-12-04 22:25:33 +0400 |
commit | 9e60c67806673932f490c0f4e3ec7cf36cbf0ea7 (patch) | |
tree | 752423ff5a3fca3245db34bbc20b0acc7cc30831 /adagrad.lua | |
parent | 907296ae976acb61df07e691d50d8a1b54693e52 (diff) |
add ADAGRAD optimization for SGD with diagonal weights
Diffstat (limited to 'adagrad.lua')
-rw-r--r-- | adagrad.lua | 48 |
1 files changed, 48 insertions, 0 deletions
----------------------------------------------------------------------
-- ADAGRAD implementation for SGD
--
-- Per-parameter adaptive learning rates: each gradient coordinate is
-- scaled by the inverse square root of the running sum of its squared
-- historical gradients (plus a small 1e-10 stabilizer to avoid
-- division by zero).
--
-- ARGS:
-- opfunc : a function that takes a single input (X), the point of
--          evaluation, and returns f(X) and df/dX
-- x      : the initial point
-- state  : a table describing the state of the optimizer; after each
--          call the state is modified
--   state.learningRate      : base learning rate (default 1e-3)
--   state.learningRateDecay : learning rate decay (default 0)
--   state.paramVariance     : vector of accumulated squared gradients
--
-- RETURN:
-- x : the new x vector
-- f(x) : the function, evaluated before the update
--
function optim.adagrad(opfunc, x, state)
   -- (0) get/update state.
   -- Without a caller-supplied table the accumulated variance cannot
   -- persist across calls, but at least do not crash: the original
   -- printed this warning and then indexed the nil `state` anyway.
   if state == nil then
      print('no state table, ADAGRAD initializing')
      state = {}
   end
   local lr  = state.learningRate or 1e-3
   local lrd = state.learningRateDecay or 0
   state.evalCounter = state.evalCounter or 0
   local nevals = state.evalCounter

   -- (1) evaluate f(x) and df/dx
   local fx, dfdx = opfunc(x)

   -- (2) learning rate decay (annealing)
   local clr = lr / (1 + nevals*lrd)

   -- (3) parameter update with individual per-coordinate learning rates:
   --     accumulate squared gradients, then scale the step by
   --     1/sqrt(accumulated variance)
   if not state.paramVariance then
      state.paramVariance = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
      state.paramStd = torch.Tensor():typeAs(x):resizeAs(dfdx)
   end
   state.paramVariance:addcmul(1, dfdx, dfdx)
   -- paramStd is fully recomputed from paramVariance on every call, so
   -- the in-place :add(1e-10) below does not corrupt the accumulator
   torch.sqrt(state.paramStd, state.paramVariance)
   x:addcdiv(-clr, dfdx, state.paramStd:add(1e-10))

   -- (4) update evaluation counter
   state.evalCounter = state.evalCounter + 1

   -- return x*, f(x) before optimization
   return x, {fx}
end