github.com/torch/optim.git
author    Clement Farabet <clement.farabet@gmail.com>  2012-01-21 11:02:08 +0400
committer Clement Farabet <clement.farabet@gmail.com>  2012-01-21 11:02:08 +0400
commit    0cbb3dbfb5247d947a2bdd1dfd34e7d44a35f67d (patch)
tree      0b4b35d92fc72565a8ace4407c921644f50b3f08 /lbfgs.lua
parent    f9cc2df0f29ba67d3ae3cbcb5026beb11cda02df (diff)
Started to fixup LBFGS
Diffstat (limited to 'lbfgs.lua')
-rw-r--r--  lbfgs.lua  |  79
1 file changed, 52 insertions(+), 27 deletions(-)
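
For reference, the patch widens the set of options that optim.lbfgs reads from its state table. A minimal sketch of that table, with field names taken from the diff below and the defaults the patched code falls back to (the table itself is illustrative, not part of the commit):

-- all fields are optional; defaults shown are what the patched code uses
local state = {
   maxIter      = 20,    -- max number of iterations
   maxEval      = 40,    -- max number of function evaluations
   tolFun       = 1e-5,  -- termination tolerance on first-order optimality
   tolX         = 1e-9,  -- termination tolerance on step size / change in f
   nCorrection  = 100,   -- number of correction pairs kept in the L-BFGS memory
   lineSearch   = nil,   -- optional line-search function (Wolfe conditions)
   lineSearchDecrease  = 1e-4, -- Wolfe sufficient-decrease constant (c1)
   lineSearchCurvature = 0.9,  -- Wolfe curvature constant (c2)
   learningRate = 1,     -- fixed step length used when no line search is given
   verbose      = false, -- print progress messages
}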
diff --git a/lbfgs.lua b/lbfgs.lua
index 740a44e..908773f 100644
--- a/lbfgs.lua
+++ b/lbfgs.lua
@@ -23,35 +23,48 @@
function optim.lbfgs(opfunc, x, state)
-- get/update state
local state = state or {}
- local maxIter = state.maxIter or 20
- local maxEval = state.maxEval or 40
+ local maxIter = tonumber(state.maxIter) or 20
+ local maxEval = tonumber(state.maxEval) or 40
local tolFun = state.tolFun or 1e-5
- local tolX = state.tolFun or 1e-9
+ local tolX = state.tolX or 1e-9
local nCorrection = state.nCorrection or 100
+ local lineSearch = state.lineSearch
local c1 = state.lineSearchDecrease or 1e-4
local c2 = state.lineSearchCurvature or 0.9
- local state.funcEval = state.funcEval or 0
+ local learningRate = state.learningRate or 1
+ local isverbose = state.verbose or false
+ state.funcEval = state.funcEval or 0
- -- lab -> local
+ -- verbose function
+ local function verbose(...)
+ if isverbose then print('<optim.lbfgs> ', ...) end
+ end
+
+ -- import some functions
local zeros = lab.zeros
local randn = lab.randn
+ local append = table.insert
+ local abs = math.abs
-- initial step length
local t = 1
-- evaluate initial f(x) and df/dx
local f,g = opfunc(x)
+ local f_hist = {f}
local currentFuncEval = 1
state.funcEval = state.funcEval + 1
-- check optimality of initial point
if g:abs():sum() <= tolFun then
-- optimality condition below tolFun
+ verbose('optimality condition below tolFun')
return x,f
end
-- optimize for a max of maxIter iterations
local nIter = 0
+ local d,old_dirs,old_stps,Hdiag,g_old
while nIter < maxIter do
-- keep track of nb of iterations
nIter = nIter + 1
@@ -59,8 +72,7 @@ function optim.lbfgs(opfunc, x, state)
------------------------------------------------------------
-- compute gradient descent direction
------------------------------------------------------------
- local d,old_dirs,old_stps,Hdiag,g_old
- if i == 1 then
+ if nIter == 1 then
d = -g
old_dirs = {zeros(g:size())}
old_stps = {zeros(d:size())}
@@ -79,14 +91,14 @@ function optim.lbfgs(opfunc, x, state)
old_dirs = {}
old_stps = {}
for i = 2,#prev_old_dirs do
- table.insert(old_dirs, prev_old_dirs[i])
- table.insert(old_stps, prev_old_stps[i])
+ append(old_dirs, prev_old_dirs[i])
+ append(old_stps, prev_old_stps[i])
end
end
-- store new direction/step
- table.insert(old_dirs, s)
- table.insert(old_stps, y)
+ append(old_dirs, s)
+ append(old_stps, y)
-- update scale of initial Hessian approximation
Hdiag = ys/(y*y)
@@ -94,7 +106,7 @@ function optim.lbfgs(opfunc, x, state)
-- compute the approximate (L-BFGS) inverse Hessian
-- multiplied by the gradient
- local p = g:size()
+ local p = g:size(1)
local k = #old_dirs
local ro = {}
@@ -110,15 +122,15 @@ function optim.lbfgs(opfunc, x, state)
q[k+1] = -g
for i = k,1,-1 do
- al[i] = ro[i] * old_dirs[i] * q[i+1]
- q[i] = q[i+1] - al[i] * old_stps[i]
+ al[i] = old_dirs[i] * q[i+1] * ro[i]
+ q[i] = q[i+1] - old_stps[i] * al[i]
end
-- multiply by initial Hessian
- r[1] = Hdiag * q[1]
+ r[1] = q[1] * Hdiag
for i = 1,k do
- be[i] = ro[i] * old_stps[i] * r[i]
+ be[i] = old_stps[i] * r[i] * ro[i]
r[i+1] = r[i] + old_dirs[i] * (al[i] - be[i])
end
@@ -126,7 +138,7 @@ function optim.lbfgs(opfunc, x, state)
d = r[k+1]
end
g_old = g:clone()
- f_old = f:clone()
+ f_old = f
------------------------------------------------------------
-- compute step length
@@ -140,37 +152,50 @@ function optim.lbfgs(opfunc, x, state)
end
-- reset initial guess for step size
- t = 1
+ t = learningRate
+
+ -- optional line search: user function
+ local lsFuncEval = 0
+ if lineSearch and type(lineSearch) == 'function' then
+ -- perform line search, satisfying Wolfe condition
+ f,g,x,t,lsFuncEval = lineSearch(opfunc,x,t,d,f,g,gtd,c1,c2,tolX)
+ append(f_hist, f)
- -- perform line search, satisfying Wolfe condition
- --[t,f,g,lsFuncEval] = WolfeLineSearch(x,t,d,f,g,gtd,c1,c2,LS=4,25,tolX,false,false,1,opfunc)
+ -- from minFunc:
+ --[t,f,g,lsFuncEval] = WolfeLineSearch(x,t,d,f,g,gtd,c1,c2,LS=4,25,tolX,false,false,1,opfunc)
+ else
+ -- no line search: take a fixed step and re-evaluate (costly, but needed by the checks below)
+ x:add(d*t)
+ f,g = opfunc(x)
+ lsFuncEval = 1
+ append(f_hist, f)
+ end
-- update func eval
currentFuncEval = currentFuncEval + lsFuncEval
state.funcEval = state.funcEval + lsFuncEval
-
- -- update parameters
- x = x + d*t
------------------------------------------------------------
-- check conditions
------------------------------------------------------------
if (d*t):abs():sum() <= tolX then
-- step size below tolX
+ verbose('step size below tolX')
break
end
- if (f-f_old):abs() < tolX then
+ if abs(f-f_old) < tolX then
-- function value changing less than tolX
+ verbose('function value changing less than tolX')
break
end
- if currentFuncEval >= state.maxEval then
+ if currentFuncEval >= maxEval then
-- max nb of function evals
+ verbose('max nb of function evals')
break
end
end
- -- return f(x_new), x_new
- return x,f
+ -- return optimal x, and history of f(x)
+ return x,f_hist
end
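
The middle hunks above touch the standard L-BFGS two-loop recursion, which applies the approximate inverse Hessian to the gradient (the reordered products such as old_dirs[i] * q[i+1] * ro[i] just keep the tensor-times-scalar operand order consistent). A rough plain-Lua rendering of that recursion, using arrays of numbers instead of lab tensors and hypothetical names (dot, twoLoop), purely for illustration:

-- dot product of two plain Lua arrays
local function dot(a, b)
   local sum = 0
   for i = 1, #a do sum = sum + a[i] * b[i] end
   return sum
end

-- two-loop recursion: returns d = -H*g, given correction pairs
-- s[i] (old steps, 'old_dirs' in the patch) and y[i] (old gradient
-- differences, 'old_stps' in the patch), plus the initial scaling Hdiag
local function twoLoop(g, s, y, Hdiag)
   local k = #s
   local ro, al = {}, {}
   for i = 1, k do ro[i] = 1 / dot(y[i], s[i]) end
   local q = {}
   for j = 1, #g do q[j] = -g[j] end            -- q = -g
   for i = k, 1, -1 do                          -- backward pass
      al[i] = ro[i] * dot(s[i], q)
      for j = 1, #q do q[j] = q[j] - al[i] * y[i][j] end
   end
   local r = {}
   for j = 1, #q do r[j] = Hdiag * q[j] end     -- scale by initial Hessian
   for i = 1, k do                              -- forward pass
      local be = ro[i] * dot(y[i], r)
      for j = 1, #r do r[j] = r[j] + s[i][j] * (al[i] - be) end
   end
   return r                                     -- descent direction -H*g
end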
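
With this commit the function also returns the full history of objective values, f_hist, instead of just the final f. A minimal usage sketch, assuming a Torch7-era session where the lab package and this patched optim are loaded (the quadratic opfunc is illustrative, not part of the patch):

require 'lab'
require 'optim'

-- illustrative objective: f(x) = 0.5*||x||^2, with gradient df/dx = x
local function opfunc(x)
   local f = 0.5 * (x * x)   -- vector*vector is a dot product
   return f, x:clone()
end

local x0 = lab.randn(10)
local x, f_hist = optim.lbfgs(opfunc, x0, {maxIter = 50, verbose = true})
print('f went from ' .. f_hist[1] .. ' to ' .. f_hist[#f_hist])

Note that the patched code updates x in place via x:add(d*t), so x0 and the returned x share storage.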