diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-05-02 23:08:08 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-05-02 23:08:08 +0400 |
commit | 6ad83ba5657a0d8fb317ad5f1e4e82d99793df18 (patch) | |
tree | 3554171245b2bbe71e09fe12a81e6d876fe2f331 | |
parent | 6c45836d1f975ba15f7aee50e093f1944b27b5fe (diff) | |
parent | bc67896b5b9fecaf3a904ff81aad0e548674dcb2 (diff) |
Merge branch 'master' of https://github.com/koraykv/optim
-rw-r--r-- | fista.lua | 58 |
1 files changed, 33 insertions, 25 deletions
@@ -102,27 +102,50 @@ function optim.FistaLS(f, g, pl, xinit, params) -- Q(beta,y) = F(y) + <beta-y , \Grad(F(y))> + L/2||beta-y||^2 + G(beta) Q = fy + Q2 + Q3 + if verbose then + print(string.format('nline=%d L=%g fply=%g Q=%g fy=%g Q2=%g Q3=%g',nline,L,fply,Q,fy,Q2,Q3)) + end -- check if F(beta) < Q(pl(y),\t) if fply <= Q then --and Fply + Gply <= F then -- now evaluate G here - gply = g(xk) linesearchdone = true elseif nline >= maxline then linesearchdone = true xk:copy(xkm) -- if we can't find a better point, current iter = previous iter - fply = f(xk) - gply = g(xk) --print('oops') else L = L * Lstep end nline = nline + 1 - if verbose then - print(niter,linesearchdone,nline,L,fy,Q2,Q3,Q,fply) - end end -- end line search --------------------------------------------- + + --------------------------------------------- + -- FISTA + --------------------------------------------- + if doFistaUpdate then + -- do the FISTA step + tkp = (1 + math.sqrt(1 + 4*tk*tk)) / 2 + -- x(k-1) = x(k-1) - x(k) + xkm:add(-1,xk) + -- y(k+1) = x(k) + (1-t(k)/t(k+1))*(x(k-1)-x(k)) + y:copy(xk) + y:add( (1-tk)/tkp , xkm) + -- store for next iterations + -- x(k-1) = x(k) + xkm:copy(xk) + else + y:copy(xk) + end + -- t(k) = t(k+1) + tk = tkp + fply = f(y) + gply = g(y) + if verbose then + print(string.format('iter=%d eold=%g enew=%g',niter,fval,fply+gply)) + end + niter = niter + 1 -- bookeeping @@ -143,11 +166,13 @@ function optim.FistaLS(f, g, pl, xinit, params) -- are we done? if niter > 1 and math.abs(history[niter].F - history[niter-1].F) <= errthres then converged = true - return xk,history + xinit:copy(y) + return y,history end if niter >= maxiter then - return xk,history + xinit:copy(y) + return y,history end --if niter > 1 and history[niter].F > history[niter-1].F then @@ -155,23 +180,6 @@ function optim.FistaLS(f, g, pl, xinit, params) --converged = true --return xk,history --end - - if doFistaUpdate then - -- do the FISTA step - tkp = (1 + math.sqrt(1 + 4*tk*tk)) / 2 - -- x(k-1) = x(k-1) - x(k) - xkm:add(-1,xk) - -- y(k+1) = x(k) + (1-t(k)/t(k+1))*(x(k-1)-x(k)) - y:copy(xk) - y:add( (1-tk)/tkp , xkm) - -- store for next iterations - -- x(k-1) = x(k) - xkm:copy(xk) - else - y:copy(xk) - end - -- t(k) = t(k+1) - tk = tkp end error('not supposed to be here') end |