From c565173bc01c6f4b1f63ef7c4ae3a832854d582a Mon Sep 17 00:00:00 2001
From: louissmit
Date: Tue, 3 Feb 2015 13:03:25 +0100
Subject: optim style comments

---
 adam.lua | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

(limited to 'adam.lua')

diff --git a/adam.lua b/adam.lua
index 2462164..1140ae9 100644
--- a/adam.lua
+++ b/adam.lua
@@ -21,6 +21,7 @@ RETURN:
 ]]
 
 function optim.adam(opfunc, x, config, state)
+   -- (0) get/update state
    local config = config or {}
    local state = state or config
    local lr = config.learningRate or 2e-6
@@ -30,10 +31,10 @@ function optim.adam(opfunc, x, config, state)
    local epsilon = config.epsilon or 10e-8
    local lambda = config.lambda or 10e-8
 
-   -- get parameters
+   -- (1) evaluate f(x) and df/dx
    local fx, dfdx = opfunc(x)
 
-   state.t = state.t or 1 -- timestep
+   state.t = state.t or 1 -- evaluation counter
    state.m = state.m or torch.Tensor():typeAs(dfdx):resizeAs(dfdx):fill(0) -- Initialize first moment vector
    state.v = state.v or torch.Tensor():typeAs(dfdx):resizeAs(dfdx):fill(0) -- Initialize second moment vector
 
@@ -43,7 +44,8 @@ function optim.adam(opfunc, x, config, state)
 
    local update = torch.cmul(state.m, torch.pow(torch.add(torch.pow(state.v, 2), epsilon),-1))
    update:mul(lr * torch.sqrt(1-torch.pow((1-beta2),2)) * torch.pow(1-torch.pow((1-beta1),2), -1)) -- compute final update
-
+
+   -- (2) update x and evaluation counter
    x:add(-update)
    state.t = state.t + 1
 
-- 
cgit v1.2.3
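
Note (not part of the patch): the step comments added above follow the optim convention that opfunc(x) returns f(x) and df/dx, and that adam updates x in place while caching t, m, and v in state. A minimal, hypothetical usage sketch under those assumptions; the names feval and params and the toy objective are illustrative, not from this commit:

-- Hypothetical usage sketch; `feval`, `params`, and the objective are illustrative.
require 'torch'
require 'optim'

local params = torch.randn(10)   -- parameter vector, updated in place by optim.adam
local state = {}                 -- persistent state: holds t, m, and v across calls

-- opfunc must return f(x) and df/dx; here f(x) = 0.5 * ||x||^2, so df/dx = x
local function feval(x)
   return 0.5 * x:dot(x), x:clone()
end

for i = 1, 100 do
   optim.adam(feval, params, {learningRate = 1e-3}, state)
end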