diff options
author | Brendan Shillingford <shillingford@google.com> | 2015-09-08 07:18:46 +0300 |
---|---|---|
committer | Brendan Shillingford <shillingford@google.com> | 2015-09-08 07:18:46 +0300 |
commit | 32854d2278c8e63c3fc05229e21eb0e9ee875c4d (patch) | |
tree | fd6e2c46ac3714dcc8b0a9d12701e3d5ce0d4363 | |
parent | a078e1edf26edb5d1ec02798903da7a8d634acb6 (diff) |
Memory usage improvements for L-BFGS
-rw-r--r-- | lbfgs.lua | 94 |
1 file changed, 47 insertions, 47 deletions
@@ -51,14 +51,14 @@ function optim.lbfgs(opfunc, x, config, state)
    state.nIter = state.nIter or 0

    -- verbose function
-   local function verbose(...)
-      if isverbose then print('<optim.lbfgs> ', ...) end
+   local verbose
+   if isverbose then
+      verbose = function(...) print('<optim.lbfgs> ', ...) end
+   else
+      verbose = function() end
    end

    -- import some functions
-   local zeros = torch.zeros
-   local randn = torch.randn
-   local append = table.insert
    local abs = math.abs
    local min = math.min
@@ -67,9 +67,10 @@ function optim.lbfgs(opfunc, x, config, state)
    local f_hist = {f}
    local currentFuncEval = 1
    state.funcEval = state.funcEval + 1
+   local p = g:size(1)

    -- check optimality of initial point
-   state.tmp1 = state.abs_g or zeros(g:size()); local tmp1 = state.tmp1
+   state.tmp1 = state.tmp1 or g.new(g:size()):zero(); local tmp1 = state.tmp1
    tmp1:copy(g):abs()
    if tmp1:sum() <= tolFun then
       -- optimality condition below tolFun
@@ -77,6 +78,14 @@ function optim.lbfgs(opfunc, x, config, state)
       return x,f_hist
    end

+   -- reusable buffers for y's and s's, and their histories
+   state.dir_bufs = state.dir_bufs or g.new(nCorrection+1, p):split(1)
+   state.stp_bufs = state.stp_bufs or g.new(nCorrection+1, p):split(1)
+   for i=1,#state.dir_bufs do
+      state.dir_bufs[i] = state.dir_bufs[i]:squeeze(1)
+      state.stp_bufs[i] = state.stp_bufs[i]:squeeze(1)
+   end
+
    -- variables cached in state (for tracing)
    local d = state.d
    local t = state.t
@@ -103,74 +112,65 @@ function optim.lbfgs(opfunc, x, config, state)
          Hdiag = 1
       else
          -- do lbfgs update (update memory)
-         local y = g:clone():add(-1, g_old) -- g - g_old
-         local s = d:clone():mul(t)         -- d*t
+         local y = table.remove(state.dir_bufs) -- pop
+         local s = table.remove(state.stp_bufs)
+         y:copy(g):add(-1, g_old) -- g - g_old
+         s:copy(d):mul(t)         -- d*t
          local ys = y:dot(s)                -- y*s
          if ys > 1e-10 then
             -- updating memory
             if #old_dirs == nCorrection then
                -- shift history by one (limited-memory)
-               local prev_old_dirs = old_dirs
-               local prev_old_stps = old_stps
-               old_dirs = {}
-               old_stps = {}
-               for i = 2,#prev_old_dirs do
-                  append(old_dirs, prev_old_dirs[i])
-                  append(old_stps, prev_old_stps[i])
-               end
+               local removed1 = table.remove(old_dirs, 1)
+               local removed2 = table.remove(old_stps, 1)
+               table.insert(state.dir_bufs, removed1)
+               table.insert(state.stp_bufs, removed2)
             end

             -- store new direction/step
-            append(old_dirs, s)
-            append(old_stps, y)
+            table.insert(old_dirs, s)
+            table.insert(old_stps, y)

             -- update scale of initial Hessian approximation
             Hdiag = ys / y:dot(y) -- (y*y)
-
-            -- cleanup
-            collectgarbage()
+         else
+            -- put y and s back into the buffer pool
+            table.insert(state.dir_bufs, y)
+            table.insert(state.stp_bufs, s)
          end

          -- compute the approximate (L-BFGS) inverse Hessian
          -- multiplied by the gradient
-         local p = g:size(1)
          local k = #old_dirs

-         state.ro = state.ro or zeros(nCorrection); local ro = state.ro
+         -- need to be accessed element-by-element, so don't re-type tensor:
+         state.ro = state.ro or torch.Tensor(nCorrection); local ro = state.ro
          for i = 1,k do
            ro[i] = 1 / old_stps[i]:dot(old_dirs[i])
          end

-         state.q = state.q or zeros(nCorrection+1,p):typeAs(g)
-         local q = state.q
-         state.r = state.r or zeros(nCorrection+1,p):typeAs(g)
-         local r = state.r
-         state.al = state.al or zeros(nCorrection):typeAs(g)
-         local al = state.al
-         state.be = state.be or zeros(nCorrection):typeAs(g)
-         local be = state.be
-
-         q[k+1] = g:clone():mul(-1) -- -g
+         -- iteration in L-BFGS loop collapsed to use just one buffer
+         local q = tmp1 -- reuse tmp1 for the q buffer
+         -- need to be accessed element-by-element, so don't re-type tensor:
+         state.al = state.al or torch.zeros(nCorrection)
+         local al = state.al
+
+         q:copy(g):mul(-1) -- -g
          for i = k,1,-1 do
-            al[i] = old_dirs[i]:dot(q[i+1]) * ro[i]
-            q[i] = q[i+1]
-            q[i]:add(-al[i], old_stps[i])
+            al[i] = old_dirs[i]:dot(q) * ro[i]
+            q:add(-al[i], old_stps[i])
          end

          -- multiply by initial Hessian
-         r[1] = q[1]:clone():mul(Hdiag) -- q[1] * Hdiag
-
+         r = d -- share the same buffer, since we don't need the old d
+         r:copy(q):mul(Hdiag) -- q[1] * Hdiag
          for i = 1,k do
-            be[i] = old_stps[i]:dot(r[i]) * ro[i]
-            r[i+1] = r[i]
-            r[i+1]:add((al[i] - be[i]), old_dirs[i])
+            local be_i = old_stps[i]:dot(r) * ro[i]
+            r:add(al[i]-be_i, old_dirs[i])
          end
-
-         -- final direction:
-         d:copy(r[k+1])
+         -- final direction is in r/d (same object)
       end
-      g_old = g:clone()
+      g_old = g_old or g:clone()
+      g_old:copy(g)
      f_old = f

      ------------------------------------------------------------
@@ -197,7 +197,7 @@ function optim.lbfgs(opfunc, x, config, state)
      if lineSearch and type(lineSearch) == 'function' then
         -- perform line search, using user function
         f,g,x,t,lsFuncEval = lineSearch(opfunc,x,t,d,f,g,gtd,lineSearchOpts)
-        append(f_hist, f)
+        table.insert(f_hist, f)
      else
         -- no line search, simply move with fixed-step
         x:add(t,d)
@@ -207,7 +207,7 @@ function optim.lbfgs(opfunc, x, config, state)
            -- no use to re-evaluate that function here
            f,g = opfunc(x)
            lsFuncEval = 1
-           append(f_hist, f)
+           table.insert(f_hist, f)
         end
      end