
github.com/torch/optim.git
author     Soumith Chintala <soumith@gmail.com>  2016-05-19 18:46:34 +0300
committer  Soumith Chintala <soumith@gmail.com>  2016-05-19 18:46:34 +0300
commit     efdc2404a843c9e16535e633df16fceb28c88bd8 (patch)
tree       ba7d8810271da3c646a87215d3d1f0a252ce71a2
parent     e9af33beb9f7b03f7cd3df7ee1b025dfd28ad7f9 (diff)
parent     62d97dde221fca236232eaabe879d3eed12de0da (diff)
Merge pull request #108 from gcheron/adadelta-wdec
add weight decay support to adadelta
-rw-r--r--  adadelta.lua  | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
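For context, a minimal usage sketch of the new option, not taken from the repository: weightDecay rides along in the same config table as rho and eps, and the adadelta call signature matches the one shown in the diff below. The toy quadratic objective is an assumed stand-in for a real model's loss closure.

require 'torch'
require 'optim'

-- Toy objective: f(x) = 0.5 * ||x||^2, so df/dx = x.
-- Stands in for whatever loss closure the surrounding training code provides.
local params = torch.randn(10)
local function feval(x)
   local fx = 0.5 * x:dot(x)
   return fx, x:clone()
end

local config = {
   rho = 0.9,           -- interpolation parameter
   eps = 1e-6,          -- numerical stability term
   weightDecay = 1e-4,  -- the new option added by this patch
}
local state = {}

for i = 1, 100 do
   optim.adadelta(feval, params, config, state)  -- updates params in place
end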
diff --git a/adadelta.lua b/adadelta.lua
index 5848d17..7cc058d 100644
--- a/adadelta.lua
+++ b/adadelta.lua
@@ -7,6 +7,7 @@ ARGS:
- `config` : a table of hyper-parameters
- `config.rho` : interpolation parameter
- `config.eps` : for numerical stability
+- `config.weightDecay` : weight decay
- `state` : a table describing the state of the optimizer; after each
call the state is modified
- `state.paramVariance` : vector of temporal variances of parameters
@@ -24,11 +25,17 @@ function optim.adadelta(opfunc, x, config, state)
local state = state or config
local rho = config.rho or 0.9
local eps = config.eps or 1e-6
+ local wd = config.weightDecay or 0
state.evalCounter = state.evalCounter or 0
-- (1) evaluate f(x) and df/dx
local fx,dfdx = opfunc(x)

- -- (2) parameter update
+ -- (2) weight decay
+ if wd ~= 0 then
+ dfdx:add(wd, x)
+ end
+
+ -- (3) parameter update
if not state.paramVariance then
state.paramVariance = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
state.paramStd = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
@@ -40,7 +47,7 @@ function optim.adadelta(opfunc, x, config, state)
state.delta:resizeAs(state.paramVariance):copy(state.accDelta):add(eps):sqrt():cdiv(state.paramStd):cmul(dfdx)
x:add(-1, state.delta)
state.accDelta:mul(rho):addcmul(1-rho, state.delta, state.delta)
- -- (3) update evaluation counter
+ -- (4) update evaluation counter
state.evalCounter = state.evalCounter + 1
-- return x*, f(x) before optimization
   return x,{fx}
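The core of the patch is step (2): torch's in-place tensor:add(scalar, tensor2) computes tensor + scalar * tensor2, so dfdx:add(wd, x) folds the gradient of an L2 penalty (wd/2) * ||x||^2 into dfdx before the usual adadelta update. A small standalone check of that call, with arbitrary example values:

require 'torch'

local x    = torch.Tensor{1.0, -2.0}   -- parameters
local dfdx = torch.Tensor{0.5,  0.5}   -- raw gradient from opfunc
local wd   = 0.1                       -- weight decay coefficient

dfdx:add(wd, x)   -- in place: dfdx <- dfdx + wd * x
print(dfdx)       -- 0.6, 0.3: the L2-regularized gradient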