| author | Soumith Chintala <soumith@gmail.com> | 2016-07-21 21:05:40 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2016-07-21 21:05:40 +0300 |
| commit | 83d952d47cafa710bb5cf95be797f4f2352baa0c (patch) | |
| tree | 1e13043d5c3c90fd82077cfaa220f5b745e20b27 | |
| parent | e24fd8550d366668473d0b5a89c00546d2145c81 (diff) | |
| parent | 7b32fd239008067287f5225dd6ee9798065d8aa1 (diff) | |
Merge pull request #122 from Cadene/master
Add LearningRateDecay to Adam
| -rw-r--r-- | adam.lua | 9 |
|---|---|---|
| -rw-r--r-- | doc/algos.md | 1 |
2 files changed, 8 insertions, 2 deletions
```diff
diff --git a/adam.lua b/adam.lua
--- a/adam.lua
+++ b/adam.lua
@@ -7,6 +7,7 @@ ARGS:
 - 'x' : the initial point
 - 'config` : a table with configuration parameters for the optimizer
 - 'config.learningRate' : learning rate
+- `config.learningRateDecay` : learning rate decay
 - 'config.beta1' : first moment coefficient
 - 'config.beta2' : second moment coefficient
 - 'config.epsilon' : for numerical stability
@@ -25,6 +26,7 @@ function optim.adam(opfunc, x, config, state)
    local config = config or {}
    local state = state or config
    local lr = config.learningRate or 0.001
+   local lrd = config.learningRateDecay or 0
 
    local beta1 = config.beta1 or 0.9
    local beta2 = config.beta2 or 0.999
@@ -48,6 +50,9 @@ function optim.adam(opfunc, x, config, state)
    -- A tmp tensor to hold the sqrt(v) + epsilon
    state.denom = state.denom or x.new(dfdx:size()):zero()
 
+   -- (3) learning rate decay (annealing)
+   local clr = lr / (1 + state.t*lrd)
+
    state.t = state.t + 1
 
    -- Decay the first and second moment running average coefficient
@@ -58,8 +63,8 @@ function optim.adam(opfunc, x, config, state)
    local biasCorrection1 = 1 - beta1^state.t
    local biasCorrection2 = 1 - beta2^state.t
-   local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1
-   -- (3) update x
+   local stepSize = clr * math.sqrt(biasCorrection2)/biasCorrection1
+   -- (4) update x
    x:addcdiv(-stepSize, state.m, state.denom)
 
    -- return x*, f(x) before optimization
diff --git a/doc/algos.md b/doc/algos.md
index a671420..a3ce681 100644
--- a/doc/algos.md
+++ b/doc/algos.md
@@ -200,6 +200,7 @@ Arguments:
 * `x`: the initial point
 * `config`: a table with configuration parameters for the optimizer
 * `config.learningRate`: learning rate
+* `config.learningRateDecay`: learning rate decay
 * `config.beta1`: first moment coefficient
 * `config.beta2`: second moment coefficient
 * `config.epsilon`: for numerical stability
```
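For context, a minimal sketch of how the new option is exercised from a training loop. The quadratic objective, tensor size, and iteration count below are illustrative, not part of the commit; only `config.learningRateDecay` and the `optim.adam` call come from the change itself:

```lua
require 'torch'
require 'optim'

-- Illustrative objective: f(x) = 0.5 * ||x||^2, so df/dx = x.
local function feval(x)
   return 0.5 * x:dot(x), x:clone()
end

local x = torch.randn(10)
local config = {
   learningRate = 0.001,      -- lr in adam.lua
   learningRateDecay = 1e-4,  -- the new lrd option; defaults to 0
}
local state = {}

for i = 1, 1000 do
   -- Each call computes clr = lr / (1 + state.t*lrd) before stepping,
   -- so the effective step size shrinks as state.t grows.
   optim.adam(feval, x, config, state)
end
```

Because `lrd` defaults to `0`, `clr` reduces to `lr` for existing callers, so the change is backward compatible.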