github.com/torch/optim.git
author     Brendan Shillingford <shillingford@google.com>  2015-09-08 07:18:46 +0300
committer  Brendan Shillingford <shillingford@google.com>  2015-09-08 07:18:46 +0300
commit     32854d2278c8e63c3fc05229e21eb0e9ee875c4d (patch)
tree       fd6e2c46ac3714dcc8b0a9d12701e3d5ce0d4363
parent     a078e1edf26edb5d1ec02798903da7a8d634acb6 (diff)
Memory usage improvements for L-BFGS
-rw-r--r--  lbfgs.lua  94
1 file changed, 47 insertions(+), 47 deletions(-)
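
The unifying change in this patch: per-iteration clone() / zeros() allocations (and the explicit collectgarbage() call) are replaced by tensors allocated once, cached in state, and recycled through a small pool with table.remove / table.insert. Below is a minimal standalone sketch of that pool pattern, assuming Torch7; the helper names makePool, acquire, and release are illustrative and do not appear in the patch.

require 'torch'

-- Sketch of the buffer-pool pattern used by this patch: preallocate one
-- (nCorrection+1) x p block, split it into p-length row views, then pop
-- and push rows instead of cloning fresh tensors every iteration.
-- makePool/acquire/release are illustrative names, not from the patch.
local function makePool(proto, n, p)
   -- proto.new(...) allocates a tensor of the same type as proto, so the
   -- pool matches the gradient's type (Float/Double/CUDA).
   local rows = proto.new(n, p):split(1)                 -- n views of 1 x p
   for i = 1, #rows do rows[i] = rows[i]:squeeze(1) end  -- 1 x p -> p
   return rows
end

local function acquire(pool) return table.remove(pool) end -- pop a buffer
local function release(pool, t) table.insert(pool, t) end  -- return it

-- usage: compute y = g - g_old without allocating
local g, g_old = torch.randn(10), torch.randn(10)
local pool = makePool(g, 5, 10)
local y = acquire(pool)
y:copy(g):add(-1, g_old)  -- in-place: y = g - g_old
release(pool, y)          -- hand the buffer back for the next iteration
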
diff --git a/lbfgs.lua b/lbfgs.lua
index c8ae057..b0531de 100644
--- a/lbfgs.lua
+++ b/lbfgs.lua
@@ -51,14 +51,14 @@ function optim.lbfgs(opfunc, x, config, state)
state.nIter = state.nIter or 0
-- verbose function
- local function verbose(...)
- if isverbose then print('<optim.lbfgs> ', ...) end
+ local verbose
+ if isverbose then
+ verbose = function(...) print('<optim.lbfgs> ', ...) end
+ else
+ verbose = function() end
end
-- import some functions
- local zeros = torch.zeros
- local randn = torch.randn
- local append = table.insert
local abs = math.abs
local min = math.min
@@ -67,9 +67,10 @@ function optim.lbfgs(opfunc, x, config, state)
local f_hist = {f}
local currentFuncEval = 1
state.funcEval = state.funcEval + 1
+ local p = g:size(1)
-- check optimality of initial point
- state.tmp1 = state.abs_g or zeros(g:size()); local tmp1 = state.tmp1
+ state.tmp1 = state.tmp1 or g.new(g:size()):zero(); local tmp1 = state.tmp1
tmp1:copy(g):abs()
if tmp1:sum() <= tolFun then
-- optimality condition below tolFun
@@ -77,6 +78,14 @@ function optim.lbfgs(opfunc, x, config, state)
return x,f_hist
end
+ -- reusable buffers for y's and s's, and their histories
+ state.dir_bufs = state.dir_bufs or g.new(nCorrection+1, p):split(1)
+ state.stp_bufs = state.stp_bufs or g.new(nCorrection+1, p):split(1)
+ for i=1,#state.dir_bufs do
+ state.dir_bufs[i] = state.dir_bufs[i]:squeeze(1)
+ state.stp_bufs[i] = state.stp_bufs[i]:squeeze(1)
+ end
+
-- variables cached in state (for tracing)
local d = state.d
local t = state.t
@@ -103,74 +112,65 @@ function optim.lbfgs(opfunc, x, config, state)
Hdiag = 1
else
-- do lbfgs update (update memory)
- local y = g:clone():add(-1, g_old) -- g - g_old
- local s = d:clone():mul(t) -- d*t
+ local y = table.remove(state.dir_bufs) -- pop
+ local s = table.remove(state.stp_bufs)
+ y:copy(g):add(-1, g_old) -- g - g_old
+ s:copy(d):mul(t) -- d*t
local ys = y:dot(s) -- y*s
if ys > 1e-10 then
-- updating memory
if #old_dirs == nCorrection then
-- shift history by one (limited-memory)
- local prev_old_dirs = old_dirs
- local prev_old_stps = old_stps
- old_dirs = {}
- old_stps = {}
- for i = 2,#prev_old_dirs do
- append(old_dirs, prev_old_dirs[i])
- append(old_stps, prev_old_stps[i])
- end
+ local removed1 = table.remove(old_dirs, 1)
+ local removed2 = table.remove(old_stps, 1)
+ table.insert(state.dir_bufs, removed1)
+ table.insert(state.stp_bufs, removed2)
end
-- store new direction/step
- append(old_dirs, s)
- append(old_stps, y)
+ table.insert(old_dirs, s)
+ table.insert(old_stps, y)
-- update scale of initial Hessian approximation
Hdiag = ys / y:dot(y) -- (y*y)
-
- -- cleanup
- collectgarbage()
+ else
+ -- put y and s back into the buffer pool
+ table.insert(state.dir_bufs, y)
+ table.insert(state.stp_bufs, s)
end
-- compute the approximate (L-BFGS) inverse Hessian
-- multiplied by the gradient
- local p = g:size(1)
local k = #old_dirs
- state.ro = state.ro or zeros(nCorrection); local ro = state.ro
+ -- need to be accessed element-by-element, so don't re-type tensor:
+ state.ro = state.ro or torch.Tensor(nCorrection); local ro = state.ro
for i = 1,k do
ro[i] = 1 / old_stps[i]:dot(old_dirs[i])
end
- state.q = state.q or zeros(nCorrection+1,p):typeAs(g)
- local q = state.q
- state.r = state.r or zeros(nCorrection+1,p):typeAs(g)
- local r = state.r
- state.al = state.al or zeros(nCorrection):typeAs(g)
- local al = state.al
- state.be = state.be or zeros(nCorrection):typeAs(g)
- local be = state.be
-
- q[k+1] = g:clone():mul(-1) -- -g
+ -- iteration in L-BFGS loop collapsed to use just one buffer
+ local q = tmp1 -- reuse tmp1 for the q buffer
+ -- need to be accessed element-by-element, so don't re-type tensor:
+ state.al = state.al or torch.zeros(nCorrection); local al = state.al
+ q:copy(g):mul(-1) -- -g
for i = k,1,-1 do
- al[i] = old_dirs[i]:dot(q[i+1]) * ro[i]
- q[i] = q[i+1]
- q[i]:add(-al[i], old_stps[i])
+ al[i] = old_dirs[i]:dot(q) * ro[i]
+ q:add(-al[i], old_stps[i])
end
-- multiply by initial Hessian
- r[1] = q[1]:clone():mul(Hdiag) -- q[1] * Hdiag
-
+ local r = d -- share the same buffer, since we don't need the old d
+ r:copy(q):mul(Hdiag) -- q[1] * Hdiag
for i = 1,k do
- be[i] = old_stps[i]:dot(r[i]) * ro[i]
- r[i+1] = r[i]
- r[i+1]:add((al[i] - be[i]), old_dirs[i])
+ local be_i = old_stps[i]:dot(r) * ro[i]
+ r:add(al[i]-be_i, old_dirs[i])
end
-
- -- final direction:
- d:copy(r[k+1])
+ -- final direction is in r/d (same object)
end
- g_old = g:clone()
+ g_old = g_old or g:clone()
+ g_old:copy(g)
f_old = f
------------------------------------------------------------
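
Note on the hunk above: the old two-loop recursion materialized full histories q[1..k+1] and r[1..k+1], but each loop iteration only reads the vector produced by the step before it, so one running buffer per loop suffices (q reuses tmp1; r aliases d). A standalone sketch of the collapsed recursion, assuming Torch7 and the patch's layout where old_dirs holds the s-vectors and old_stps the y-vectors; folding q and r into a single buffer is a further simplification made here for brevity, not by the patch.

require 'torch'

-- Collapsed L-BFGS two-loop recursion: computes d = -H*g in one buffer q.
local function twoLoop(g, old_dirs, old_stps, ro, Hdiag, q, al)
   local k = #old_dirs
   q:copy(g):mul(-1)                      -- q = -g
   for i = k, 1, -1 do                    -- backward pass
      al[i] = old_dirs[i]:dot(q) * ro[i]  -- al[i] = ro[i] * (s_i . q)
      q:add(-al[i], old_stps[i])          -- q = q - al[i] * y_i
   end
   q:mul(Hdiag)                           -- scale by initial Hessian
   for i = 1, k do                        -- forward pass
      local be = old_stps[i]:dot(q) * ro[i]
      q:add(al[i] - be, old_dirs[i])      -- q = q + (al[i] - be) * s_i
   end
   return q                               -- the search direction
end

-- with an empty history this reduces to steepest descent: d = -g
local g = torch.randn(4)
print(twoLoop(g, {}, {}, torch.Tensor(5), 1, g.new(4), torch.Tensor(5)))
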
@@ -197,7 +197,7 @@ function optim.lbfgs(opfunc, x, config, state)
if lineSearch and type(lineSearch) == 'function' then
-- perform line search, using user function
f,g,x,t,lsFuncEval = lineSearch(opfunc,x,t,d,f,g,gtd,lineSearchOpts)
- append(f_hist, f)
+ table.insert(f_hist, f)
else
-- no line search, simply move with fixed-step
x:add(t,d)
@@ -207,7 +207,7 @@ function optim.lbfgs(opfunc, x, config, state)
-- no use to re-evaluate that function here
f,g = opfunc(x)
lsFuncEval = 1
- append(f_hist, f)
+ table.insert(f_hist, f)
end
end
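
All of the scratch tensors above live in state, so the savings only materialize if the caller passes the same table on every call. A usage sketch, assuming the standard optim calling convention in which config and state may be one table; the quadratic objective is only an illustration.

require 'torch'
require 'optim'

-- Toy objective f(x) = 0.5 * ||x||^2; opfunc must return f and the gradient.
local function opfunc(x)
   return 0.5 * x:dot(x), x:clone()
end

local x = torch.randn(100)
local state = {maxIter = 10, nCorrection = 5}  -- one table, reused every call
for step = 1, 20 do
   -- passing the same state lets lbfgs keep tmp1, dir_bufs, stp_bufs, d, t
   -- (and the y/s histories) alive across calls instead of reallocating them
   optim.lbfgs(opfunc, x, state)
end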