github.com/torch/optim.git
author    Soumith Chintala <soumith@gmail.com>  2016-07-30 01:44:09 +0300
committer Soumith Chintala <soumith@gmail.com>  2016-07-30 01:44:09 +0300
commit    a959ba3fd6abb3526836b9e3ae514e2adc400dd6 (patch)
tree      5dd0c211f02b070671282705de176a9df214e841
parent    83d952d47cafa710bb5cf95be797f4f2352baa0c (diff)
fixing to be tensor type agnostic
-rw-r--r--   lswolfe.lua   29
1 file changed, 14 insertions(+), 15 deletions(-)
diff --git a/lswolfe.lua b/lswolfe.lua
index d7beb4f..0afbdbe 100644
--- a/lswolfe.lua
+++ b/lswolfe.lua
@@ -34,7 +34,6 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
local abs = torch.abs
local min = math.min
local max = math.max
- local Tensor = torch.Tensor
-- verbose function
local function verbose(...)
@@ -56,25 +55,25 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
while LSiter < maxIter do
-- check conditions:
if (f_new > (f + c1*t*gtd)) or (LSiter > 1 and f_new >= f_prev) then
- bracket = Tensor{t_prev,t}
- bracketFval = Tensor{f_prev,f_new}
- bracketGval = Tensor(2,g_new:size(1))
+ bracket = x.new{t_prev,t}
+ bracketFval = x.new{f_prev,f_new}
+ bracketGval = x.new(2,g_new:size(1))
bracketGval[1] = g_prev
bracketGval[2] = g_new
break
elseif abs(gtd_new) <= -c2*gtd then
- bracket = Tensor{t}
- bracketFval = Tensor{f_new}
- bracketGval = Tensor(1,g_new:size(1))
+ bracket = x.new{t}
+ bracketFval = x.new{f_new}
+ bracketGval = x.new(1,g_new:size(1))
bracketGval[1] = g_new
done = true
break
elseif gtd_new >= 0 then
- bracket = Tensor{t_prev,t}
- bracketFval = Tensor{f_prev,f_new}
- bracketGval = Tensor(2,g_new:size(1))
+ bracket = x.new{t_prev,t}
+ bracketFval = x.new{f_prev,f_new}
+ bracketGval = x.new(2,g_new:size(1))
bracketGval[1] = g_prev
bracketGval[2] = g_new
break
@@ -86,7 +85,7 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
t_prev = t
local minStep = t + 0.01*(t-tmp)
local maxStep = t*10
- t = optim.polyinterp(Tensor{{tmp,f_prev,gtd_prev},
+ t = optim.polyinterp(x.new{{tmp,f_prev,gtd_prev},
{t,f_new,gtd_new}},
minStep, maxStep)
@@ -104,9 +103,9 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
-- reached max nb of iterations?
if LSiter == maxIter then
- bracket = Tensor{0,t}
- bracketFval = Tensor{f,f_new}
- bracketGval = Tensor(2,g_new:size(1))
+ bracket = x.new{0,t}
+ bracketFval = x.new{f,f_new}
+ bracketGval = x.new(2,g_new:size(1))
bracketGval[1] = g
bracketGval[2] = g_new
end
@@ -123,7 +122,7 @@ function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
local HIpos = -LOpos+3
-- compute new trial value
- t = optim.polyinterp(Tensor{{bracket[1],bracketFval[1],bracketGval[1]*d},
+ t = optim.polyinterp(x.new{{bracket[1],bracketFval[1],bracketGval[1]*d},
{bracket[2],bracketFval[2],bracketGval[2]*d}})
     -- test that we are making sufficient progress
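
For context, the pattern this patch adopts is x.new(...), which constructs a tensor of the same type as x (FloatTensor, DoubleTensor, CudaTensor, ...), whereas the removed torch.Tensor(...) always builds the default tensor type. The following is a minimal sketch of that idea, not part of the commit; the helper name makeBracket and the sample values are illustrative only.

-- Minimal sketch (not from the commit): why x.new keeps the line search
-- tensor-type agnostic. 'makeBracket' is a hypothetical helper.
require 'torch'

local function makeBracket(x, t_prev, t)
   -- x.new{...} builds a tensor of the same type as x, so the bracket
   -- matches whatever tensor type the caller passed in.
   return x.new{t_prev, t}
end

local xf = torch.FloatTensor(3):fill(1)
local xd = torch.DoubleTensor(3):fill(1)

print(torch.type(makeBracket(xf, 0, 1)))  --> torch.FloatTensor
print(torch.type(makeBracket(xd, 0, 1)))  --> torch.DoubleTensor

-- By contrast, torch.Tensor{t_prev, t} always returns the default tensor
-- type (torch.DoubleTensor unless changed via torch.setdefaulttensortype),
-- which breaks when the caller works in float or CUDA tensors.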