author | Clement Farabet <clement.farabet@gmail.com> | 2012-01-21 11:02:08 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-01-21 11:02:08 +0400 |
commit | 0cbb3dbfb5247d947a2bdd1dfd34e7d44a35f67d (patch) | |
tree | 0b4b35d92fc72565a8ace4407c921644f50b3f08 /lbfgs.lua | |
parent | f9cc2df0f29ba67d3ae3cbcb5026beb11cda02df (diff) |
Started to fixup LBFGS
Diffstat (limited to 'lbfgs.lua')
-rw-r--r-- | lbfgs.lua | 79
1 file changed, 52 insertions, 27 deletions
@@ -23,35 +23,48 @@ function optim.lbfgs(opfunc, x, state)
    -- get/update state
    local state = state or {}
-   local maxIter = state.maxIter or 20
-   local maxEval = state.maxEval or 40
+   local maxIter = tonumber(state.maxIter) or 20
+   local maxEval = tonumber(state.maxEval) or 40
    local tolFun = state.tolFun or 1e-5
-   local tolX = state.tolFun or 1e-9
+   local tolX = state.tolX or 1e-9
    local nCorrection = state.nCorrection or 100
+   local lineSearch = state.lineSearch
    local c1 = state.lineSearchDecrease or 1e-4
    local c2 = state.lineSearchCurvature or 0.9
-   local state.funcEval = state.funcEval or 0
+   local learningRate = state.learningRate or 1
+   local isverbose = state.verbose or false
+   state.funcEval = state.funcEval or 0
 
-   -- lab -> local
+   -- verbose function
+   local function verbose(...)
+      if isverbose then print('<optim.lbfgs> ', ...) end
+   end
+
+   -- import some functions
    local zeros = lab.zeros
    local randn = lab.randn
+   local append = table.insert
+   local abs = math.abs
 
    -- initial step length
    local t = 1
 
    -- evaluate initial f(x) and df/dx
    local f,g = opfunc(x)
+   local f_hist = {f}
    local currentFuncEval = 1
    state.funcEval = state.funcEval + 1
 
    -- check optimality of initial point
    if g:abs():sum() <= tolFun then
       -- optimality condition below tolFun
+      verbose('optimality condition below tolFun')
       return x,f
    end
 
    -- optimize for a max of maxIter iterations
    local nIter = 0
+   local d,old_dirs,old_stps,Hdiag,g_old
    while nIter < maxIter do
       -- keep track of nb of iterations
       nIter = nIter + 1
@@ -59,8 +72,7 @@
       ------------------------------------------------------------
       -- compute gradient descent direction
       ------------------------------------------------------------
-      local d,old_dirs,old_stps,Hdiag,g_old
-      if i == 1 then
+      if nIter == 1 then
          d = -g
         old_dirs = {zeros(g:size())}
         old_stps = {zeros(d:size())}
@@ -79,14 +91,14 @@
            old_dirs = {}
           old_stps = {}
           for i = 2,#prev_old_dirs do
-             table.insert(old_dirs, prev_old_dirs[i])
-             table.insert(old_stps, prev_old_stps[i])
+             append(old_dirs, prev_old_dirs[i])
+             append(old_stps, prev_old_stps[i])
           end
        end
 
        -- store new direction/step
-       table.insert(old_dirs, s)
-       table.insert(old_stps, y)
+       append(old_dirs, s)
+       append(old_stps, y)
 
        -- update scale of initial Hessian approximation
        Hdiag = ys/(y*y)
@@ -94,7 +106,7 @@
        -- compute the approximate (L-BFGS) inverse Hessian
        -- multiplied by the gradient
-       local p = g:size()
+       local p = g:size(1)
        local k = #old_dirs
 
        local ro = {}
@@ -110,15 +122,15 @@
        q[k+1] = -g
 
        for i = k,1,-1 do
-          al[i] = ro[i] * old_dirs[i] * q[i+1]
-          q[i] = q[i+1] - al[i] * old_stps[i]
+          al[i] = old_dirs[i] * q[i+1] * ro[i]
+          q[i] = q[i+1] - old_stps[i] * al[i]
        end
 
        -- multiply by initial Hessian
-       r[1] = Hdiag * q[1]
+       r[1] = q[1] * Hdiag
 
        for i = 1,k do
-          be[i] = ro[i] * old_stps[i] * r[i]
+          be[i] = old_stps[i] * r[i] * ro[i]
           r[i+1] = r[i] + old_dirs[i] * (al[i] - be[i])
        end
@@ -126,7 +138,7 @@
        d = r[k+1]
     end
     g_old = g:clone()
-    f_old = f:clone()
+    f_old = f
 
     ------------------------------------------------------------
     -- compute step length
@@ -140,37 +152,50 @@
     end
 
     -- reset initial guess for step size
-    t = 1
+    t = learningRate
+
+    -- optional line search: user function
+    local lsFuncEval = 0
+    if lineSearch and type(lineSearch) == 'function' then
+       -- perform line search, satisfying Wolfe condition
+       f,g,x,t,lsFuncEval = lineSearch(opfunc,x,t,d,f,g,gtd,c1,c2,tolX)
+       append(f_hist, f)
 
-    -- perform line search, satisfying Wolfe condition
-    --[t,f,g,lsFuncEval] = WolfeLineSearch(x,t,d,f,g,gtd,c1,c2,LS=4,25,tolX,false,false,1,opfunc)
+       -- from minFunc:
+       --[t,f,g,lsFuncEval] = WolfeLineSearch(x,t,d,f,g,gtd,c1,c2,LS=4,25,tolX,false,false,1,opfunc)
+    else
+       -- no line search, simply re-evaluate (costly & stupid but needed by check below)
+       x:add(d*t)
+       f,g = opfunc(x); lsFuncEval = 1  -- count this evaluation too
+       append(f_hist, f)
+    end
 
     -- update func eval
     currentFuncEval = currentFuncEval + lsFuncEval
     state.funcEval = state.funcEval + lsFuncEval
-
-    -- update parameters
-    x = x + d*t
 
     ------------------------------------------------------------
     -- check conditions
     ------------------------------------------------------------
     if (d*t):abs():sum() <= tolX then
        -- step size below tolX
+       verbose('step size below tolX')
        break
     end
 
-    if (f-f_old):abs() < tolX then
+    if abs(f-f_old) < tolX then
        -- function value changing less than tolX
+       verbose('function value changing less than tolX')
        break
     end
 
-    if currentFuncEval >= state.maxEval then
+    if currentFuncEval >= maxEval then
       -- max nb of function evals
+       verbose('max nb of function evals')
       break
    end
 end
 
- -- return f(x_new), x_new
- return x,f
+ -- return optimal x, and history of f(x)
+ return x,f_hist
 end
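For reference, the two inner loops whose operand order this commit adjusts implement the standard L-BFGS two-loop recursion (Nocedal & Wright, Algorithm 7.4). In this file's naming, `old_dirs[i]` holds the steps $s_i$ and `old_stps[i]` the gradient differences $y_i$ (see the `append(old_dirs, s)` / `append(old_stps, y)` hunk), and `ro[i]` is assumed to be the usual $\rho_i = 1/(y_i^\top s_i)$, computed in lines this diff does not show:

$$
\begin{aligned}
&q_{k+1} = -g; \qquad \alpha_i = \rho_i\, s_i^\top q_{i+1}, \quad q_i = q_{i+1} - \alpha_i\, y_i \qquad (i = k, \dots, 1)\\
&r_1 = H^0 q_1, \qquad H^0 = \frac{s_k^\top y_k}{y_k^\top y_k}\, I \quad (\texttt{Hdiag})\\
&\beta_i = \rho_i\, y_i^\top r_i, \quad r_{i+1} = r_i + s_i\, (\alpha_i - \beta_i) \qquad (i = 1, \dots, k)
\end{aligned}
$$

Because `q[k+1]` is seeded with `-g`, the final `r[k+1]` is already the search direction $d \approx -H\,\nabla f(x)$, with $H$ the L-BFGS inverse-Hessian approximation; moving the scalar `ro[i]` to the end of each product leaves this math unchanged, since $\rho_i$ is a scalar.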
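All tunables introduced or honored by this commit (`learningRate`, `verbose`, `tolX`, the now-numeric `maxIter`/`maxEval`) travel through the `state` table, and the function now returns the history of `f(x)` instead of the final value alone. A minimal usage sketch, assuming the Torch-era `lab` and `optim` packages the file itself relies on; the quadratic objective, `lab.eye`, and all sizes are made up for illustration:

```lua
require 'lab'
require 'optim'

-- hypothetical objective: f(x) = 0.5 x'Ax - b'x, with gradient Ax - b
local A = lab.eye(10) * 2
local b = lab.randn(10)
local function opfunc(x)
   local dfdx = A*x - b             -- gradient (a tensor)
   local fx = 0.5*(x*(A*x)) - b*x   -- value ('*' between two vectors is a dot product)
   return fx, dfdx
end

local state = {
   maxIter = 50,         -- iteration cap
   maxEval = 100,        -- function-evaluation cap
   tolX = 1e-9,          -- now correctly read from state.tolX
   learningRate = 0.1,   -- fixed step size, used when no lineSearch is given
   verbose = true,       -- print the termination reason
}

local x, f_hist = optim.lbfgs(opfunc, lab.zeros(10), state)
print('final f(x): ' .. f_hist[#f_hist])
```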
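A user-supplied `state.lineSearch` must return `f, g, x, t, lsFuncEval` to match the call site above; the commented-out `WolfeLineSearch` line records the minFunc routine this port is modeled on. Until that is ported, a stand-in could look like the sketch below. It enforces only the sufficient-decrease (Armijo) condition through `c1` and ignores the curvature parameter `c2`, so it is a shape-compatible placeholder, not the Wolfe search the comment refers to:

```lua
-- hypothetical plug-in line search: plain backtracking on the Armijo
-- condition f(x + t*d) <= f(x) + c1*t*gtd, where gtd = g'd is the
-- directional derivative supplied by the caller
local function backtracking(opfunc, x, t, d, f, g, gtd, c1, c2, tolX)
   local nEval = 0
   while true do
      local xNew = x + d*t             -- trial point (x itself is left untouched)
      local fNew, gNew = opfunc(xNew)
      nEval = nEval + 1
      if fNew <= f + c1*t*gtd then     -- sufficient decrease: accept the step
         return fNew, gNew, xNew, t, nEval
      end
      t = t/2                          -- halve the step and retry
      if t < tolX then                 -- step has become negligible: give up
         return fNew, gNew, xNew, t, nEval
      end
   end
end

-- plugged in via the state table:
-- state.lineSearch = backtracking
```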