diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-07-30 01:44:09 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-07-30 01:44:09 +0300 |
commit | a959ba3fd6abb3526836b9e3ae514e2adc400dd6 (patch) | |
tree | 5dd0c211f02b070671282705de176a9df214e841 | |
parent | 83d952d47cafa710bb5cf95be797f4f2352baa0c (diff) |
fixing to be tensor type agnostic
-rw-r--r-- | lswolfe.lua | 29 |
1 files changed, 14 insertions, 15 deletions
diff --git a/lswolfe.lua b/lswolfe.lua index d7beb4f..0afbdbe 100644 --- a/lswolfe.lua +++ b/lswolfe.lua @@ -34,7 +34,6 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options) local abs = torch.abs local min = math.min local max = math.max - local Tensor = torch.Tensor -- verbose function local function verbose(...) @@ -56,25 +55,25 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options) while LSiter < maxIter do -- check conditions: if (f_new > (f + c1*t*gtd)) or (LSiter > 1 and f_new >= f_prev) then - bracket = Tensor{t_prev,t} - bracketFval = Tensor{f_prev,f_new} - bracketGval = Tensor(2,g_new:size(1)) + bracket = x.new{t_prev,t} + bracketFval = x.new{f_prev,f_new} + bracketGval = x.new(2,g_new:size(1)) bracketGval[1] = g_prev bracketGval[2] = g_new break elseif abs(gtd_new) <= -c2*gtd then - bracket = Tensor{t} - bracketFval = Tensor{f_new} - bracketGval = Tensor(1,g_new:size(1)) + bracket = x.new{t} + bracketFval = x.new{f_new} + bracketGval = x.new(1,g_new:size(1)) bracketGval[1] = g_new done = true break elseif gtd_new >= 0 then - bracket = Tensor{t_prev,t} - bracketFval = Tensor{f_prev,f_new} - bracketGval = Tensor(2,g_new:size(1)) + bracket = x.new{t_prev,t} + bracketFval = x.new{f_prev,f_new} + bracketGval = x.new(2,g_new:size(1)) bracketGval[1] = g_prev bracketGval[2] = g_new break @@ -86,7 +85,7 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options) t_prev = t local minStep = t + 0.01*(t-tmp) local maxStep = t*10 - t = optim.polyinterp(Tensor{{tmp,f_prev,gtd_prev}, + t = optim.polyinterp(x.new{{tmp,f_prev,gtd_prev}, {t,f_new,gtd_new}}, minStep, maxStep) @@ -104,9 +103,9 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options) -- reached max nb of iterations? if LSiter == maxIter then - bracket = Tensor{0,t} - bracketFval = Tensor{f,f_new} - bracketGval = Tensor(2,g_new:size(1)) + bracket = x.new{0,t} + bracketFval = x.new{f,f_new} + bracketGval = x.new(2,g_new:size(1)) bracketGval[1] = g bracketGval[2] = g_new end @@ -123,7 +122,7 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options) local HIpos = -LOpos+3 -- compute new trial value - t = optim.polyinterp(Tensor{{bracket[1],bracketFval[1],bracketGval[1]*d}, + t = optim.polyinterp(x.new{{bracket[1],bracketFval[1],bracketGval[1]*d}, {bracket[2],bracketFval[2],bracketGval[2]*d}}) -- test what we are making sufficient progress |