diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-05-19 18:46:34 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-05-19 18:46:34 +0300 |
commit | efdc2404a843c9e16535e633df16fceb28c88bd8 (patch) | |
tree | ba7d8810271da3c646a87215d3d1f0a252ce71a2 | |
parent | e9af33beb9f7b03f7cd3df7ee1b025dfd28ad7f9 (diff) | |
parent | 62d97dde221fca236232eaabe879d3eed12de0da (diff) |
Merge pull request #108 from gcheron/adadelta-wdec
add weight decay support to adadelta
-rw-r--r-- | adadelta.lua | 11 |
1 files changed, 9 insertions, 2 deletions
```diff
diff --git a/adadelta.lua b/adadelta.lua
index 5848d17..7cc058d 100644
--- a/adadelta.lua
+++ b/adadelta.lua
@@ -7,6 +7,7 @@ ARGS:
 - `config` : a table of hyper-parameters
 - `config.rho` : interpolation parameter
 - `config.eps` : for numerical stability
+- `config.weightDecay` : weight decay
 - `state` : a table describing the state of the optimizer; after each call the state is modified
 - `state.paramVariance` : vector of temporal variances of parameters
@@ -24,11 +25,17 @@ function optim.adadelta(opfunc, x, config, state)
     local state = state or config
     local rho = config.rho or 0.9
     local eps = config.eps or 1e-6
+    local wd = config.weightDecay or 0
     state.evalCounter = state.evalCounter or 0
     -- (1) evaluate f(x) and df/dx
     local fx,dfdx = opfunc(x)
 
-    -- (2) parameter update
+    -- (2) weight decay
+    if wd ~= 0 then
+      dfdx:add(wd, x)
+    end
+
+    -- (3) parameter update
     if not state.paramVariance then
         state.paramVariance = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
         state.paramStd = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
@@ -40,7 +47,7 @@ function optim.adadelta(opfunc, x, config, state)
     state.delta:resizeAs(state.paramVariance):copy(state.accDelta):add(eps):sqrt():cdiv(state.paramStd):cmul(dfdx)
     x:add(-1, state.delta)
     state.accDelta:mul(rho):addcmul(1-rho, state.delta, state.delta)
-    -- (3) update evaluation counter
+    -- (4) update evaluation counter
     state.evalCounter = state.evalCounter + 1
 
     -- return x*, f(x) before optimization
```