Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-10-09 19:41:10 +0300
committerGitHub <noreply@github.com>2016-10-09 19:41:10 +0300
commit9f6367cff15592db3321a2913f47dacb2abc3c3e (patch)
tree84c7bc8df5486b67ba3235cf8597e0f2783ba456
parent2fe04c2e6a7299faec1f74f7c87aa4efbde38365 (diff)
parentd398cd38f0bab6de0320cf71acb3322da84e6382 (diff)
Merge pull request #138 from DmitryUlyanov/master
Fix polyinterp to let lbfgs wtih lswolfe work on GPU
-rw-r--r--polyinterp.lua50
1 files changed, 14 insertions, 36 deletions
diff --git a/polyinterp.lua b/polyinterp.lua
index 5975981..35317ac 100644
--- a/polyinterp.lua
+++ b/polyinterp.lua
@@ -32,23 +32,23 @@ local function roots(c)
local n = c:size(1)-1
if n == 1 then
- local e = torch.Tensor({{-c[2]/c[1], 0}})
+ local e = c.new({{-c[2]/c[1], 0}})
if nz > 0 then
- return torch.cat(e, torch.zeros(nz, 2), 1)
+ return torch.cat(e, c.new(nz, 2):zero(), 1)
else
return e
end
elseif n > 1 then
- local A = torch.diag(torch.ones(n-1),-1)
+ local A = torch.diag(c.new(n-1):fill(1),-1)
A[1] = -c[{ {2,n+1} }]/c[1];
local e = torch.eig(A,'N')
if nz > 0 then
- return torch.cat(e, torch.zeros(nz,2), 1)
+ return torch.cat(e, c.new(nz,2):zero(), 1)
else
return e
end
else
- return torch.zeros(nz,2)
+ return c.new(nz,2):zero()
end
end
@@ -60,7 +60,7 @@ end
local function imag(x)
if type(x) == 'number' then return 0 end
if x:nDimension() == 1 then
- return torch.zeros(x:size(1))
+ return x.new(x:size(1)):zero()
else
return x[{ {}, 2}]
end
@@ -95,8 +95,6 @@ function optim.polyinterp(points,xminBound,xmaxBound)
-- locals
local sqrt = torch.sqrt
local mean = torch.mean
- local Tensor = torch.Tensor
- local zeros = torch.zeros
local max = math.max
local min = math.min
@@ -147,10 +145,10 @@ function optim.polyinterp(points,xminBound,xmaxBound)
xmaxBound = xmaxBound or xmax
-- Add constraints on function values
- local A = zeros(nPoints*2,order+1)
- local b = zeros(nPoints*2,1)
+ local A = points.new(nPoints*2,order+1):zero()
+ local b = points.new(nPoints*2,1):zero()
for i = 1,nPoints do
- local constraint = zeros(order+1)
+ local constraint = points.new(order+1):zero()
for j = order,0,-1 do
constraint[order-j+1] = points[{i,1}]^j
end
@@ -160,7 +158,7 @@ function optim.polyinterp(points,xminBound,xmaxBound)
-- Add constraints based on derivatives
for i = 1,nPoints do
- local constraint = zeros(order+1)
+ local constraint = points.new(order+1):zero()
for j = 1,order do
constraint[j] = (order-j+1)*points[{i,1}]^(order-j)
end
@@ -172,13 +170,10 @@ function optim.polyinterp(points,xminBound,xmaxBound)
local res = torch.gels(b,A)
local params = res[{ {1,nPoints*2} }]:squeeze()
- --print(A)
- --print(b)
- --print(params)
params[torch.le(torch.abs(params),1e-12)]=0
-- Compute Critical Points
- local dParams = zeros(order);
+ local dParams = points.new(order):zero();
for i = 1,params:size(1)-1 do
dParams[i] = params[i]*(order-i+1)
end
@@ -188,46 +183,29 @@ function optim.polyinterp(points,xminBound,xmaxBound)
if torch.ne(dParams,dParams):max() > 0 or torch.eq(dParams,math.huge):max() > 0 then
nans = true
end
- -- for i = 1,dParams:size(1) do
- -- if dParams[i] ~= dParams[i] or dParams[i] == math.huge then
- -- nans = true
- -- break
- -- end
- -- end
- local cp = torch.cat(Tensor{xminBound,xmaxBound},points[{{},1}])
+
+ local cp = torch.cat(points.new{xminBound,xmaxBound},points[{{},1}])
if not nans then
local cproots = roots(dParams)
- local cpi = zeros(cp:size(1),2)
+ local cpi = points.new(cp:size(1),2):zero()
cpi[{ {1,cp:size(1)} , 1 }] = cp
cp = torch.cat(cpi,cproots,1)
end
- --print(dParams)
- --print(cp)
-
-- Test Critical Points
local fmin = math.huge
-- Default to Bisection if no critical points valid:
minPos = (xminBound+xmaxBound)/2
- --print(minPos,fmin)
- --print(xminBound,xmaxBound)
for i = 1,cp:size(1) do
local xCP = cp[{ {i,i} , {} }]
- --print('xcp=')
- --print(xCP)
local ixCP = imag(xCP)[1]
local rxCP = real(xCP)[1]
if ixCP == 0 and rxCP >= xminBound and rxCP <= xmaxBound then
local fCP = polyval(params,rxCP)
- --print('fcp=')
- --print(fCP)
- --print(fCP < fmin)
if fCP < fmin then
minPos = rxCP
fmin = fCP
- --print('u',minPos,fmin)
end
- --print('v',minPos,fmin)
end
end
return minPos,fmin