From 6534bfd77a7ef5f9dd116d03d35c6dbf9ab0bce4 Mon Sep 17 00:00:00 2001
From: gcheron
Date: Fri, 10 Jun 2016 17:24:02 +0200
Subject: add weight decay support to adam

---
 adam.lua | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/adam.lua b/adam.lua
index 89dd793..a6ad588 100644
--- a/adam.lua
+++ b/adam.lua
@@ -10,6 +10,7 @@ ARGS:
 - 'config.beta1' : first moment coefficient
 - 'config.beta2' : second moment coefficient
 - 'config.epsilon' : for numerical stability
+- 'config.weightDecay' : weight decay
 - 'state' : a table describing the state of the optimizer; after each call the state is modified

@@ -28,10 +29,16 @@ function optim.adam(opfunc, x, config, state)
    local beta1 = config.beta1 or 0.9
    local beta2 = config.beta2 or 0.999
    local epsilon = config.epsilon or 1e-8
+   local wd = config.weightDecay or 0

    -- (1) evaluate f(x) and df/dx
    local fx, dfdx = opfunc(x)

+   -- (2) weight decay
+   if wd ~= 0 then
+      dfdx:add(wd, x)
+   end
+
    -- Initialization
    state.t = state.t or 0
    -- Exponential moving average of gradient values
@@ -52,7 +59,7 @@
    local biasCorrection1 = 1 - beta1^state.t
    local biasCorrection2 = 1 - beta2^state.t
    local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1
-   -- (2) update x
+   -- (3) update x
    x:addcdiv(-stepSize, state.m, state.denom)

    -- return x*, f(x) before optimization
--
cgit v1.2.3
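
After this patch, a non-zero config.weightDecay adds wd * x to the gradient (L2 regularization) before the Adam moment updates. The Lua sketch below is not part of the commit; it only illustrates how the new option would be passed to optim.adam, and the toy quadratic objective, tensor sizes, and config values are assumptions chosen for illustration.

-- Usage sketch (illustrative; objective, sizes, and values are assumptions)
require 'torch'
require 'optim'

local x = torch.randn(10)          -- parameters to optimize
local target = torch.zeros(10)

-- opfunc returns f(x) and df/dx, as optim.adam expects
local function opfunc(params)
   local diff = params - target
   local loss = 0.5 * diff:dot(diff)
   return loss, diff
end

local config = {
   learningRate = 1e-3,
   weightDecay  = 1e-4,            -- the new option added by this patch
}
local state = {}

-- run a few optimization steps; Adam state is kept in 'state' across calls
for i = 1, 100 do
   optim.adam(opfunc, x, config, state)
end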