Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDmitry Ulyanov <dmitry.ulyanov.msu@gmail.com>2016-10-09 18:46:24 +0300
committerDmitry Ulyanov <dmitry.ulyanov.msu@gmail.com>2016-10-09 18:46:24 +0300
commit890f7fa013ba9e59b74578c6d5b916014aac7efd (patch)
tree3fd340ff5a7d76d760ac6f57e9e9cc5a0be9254e
parent2fe04c2e6a7299faec1f74f7c87aa4efbde38365 (diff)
fix polyinterp, so lswolfe can be used with CUDA
-rw-r--r--polyinterp.lua33
1 files changed, 18 insertions, 15 deletions
diff --git a/polyinterp.lua b/polyinterp.lua
index 5975981..5c7e49b 100644
--- a/polyinterp.lua
+++ b/polyinterp.lua
@@ -32,23 +32,23 @@ local function roots(c)
local n = c:size(1)-1
if n == 1 then
- local e = torch.Tensor({{-c[2]/c[1], 0}})
+ local e = c.new({{-c[2]/c[1], 0}})
if nz > 0 then
- return torch.cat(e, torch.zeros(nz, 2), 1)
+ return torch.cat(e, c.new(nz, 2):zero(), 1)
else
return e
end
elseif n > 1 then
- local A = torch.diag(torch.ones(n-1),-1)
+ local A = torch.diag(c.new(n-1):fill(1),-1)
A[1] = -c[{ {2,n+1} }]/c[1];
local e = torch.eig(A,'N')
if nz > 0 then
- return torch.cat(e, torch.zeros(nz,2), 1)
+ return torch.cat(e, c.new(nz,2):zero(), 1)
else
return e
end
else
- return torch.zeros(nz,2)
+ return c.new(nz,2):zero()
end
end
@@ -60,7 +60,7 @@ end
local function imag(x)
if type(x) == 'number' then return 0 end
if x:nDimension() == 1 then
- return torch.zeros(x:size(1))
+ return x.new(x:size(1)):zero()
else
return x[{ {}, 2}]
end
@@ -95,8 +95,8 @@ function optim.polyinterp(points,xminBound,xmaxBound)
-- locals
local sqrt = torch.sqrt
local mean = torch.mean
- local Tensor = torch.Tensor
- local zeros = torch.zeros
+ -- local Tensor = torch.Tensor
+ -- local zeros = torch.zeros
local max = math.max
local min = math.min
@@ -147,10 +147,10 @@ function optim.polyinterp(points,xminBound,xmaxBound)
xmaxBound = xmaxBound or xmax
-- Add constraints on function values
- local A = zeros(nPoints*2,order+1)
- local b = zeros(nPoints*2,1)
+ local A = points.new(nPoints*2,order+1):zero()
+ local b = points.new(nPoints*2,1):zero()
for i = 1,nPoints do
- local constraint = zeros(order+1)
+ local constraint = points.new(order+1):zero()
for j = order,0,-1 do
constraint[order-j+1] = points[{i,1}]^j
end
@@ -160,7 +160,7 @@ function optim.polyinterp(points,xminBound,xmaxBound)
-- Add constraints based on derivatives
for i = 1,nPoints do
- local constraint = zeros(order+1)
+ local constraint = points.new(order+1):zero()
for j = 1,order do
constraint[j] = (order-j+1)*points[{i,1}]^(order-j)
end
@@ -169,6 +169,7 @@ function optim.polyinterp(points,xminBound,xmaxBound)
end
-- Find interpolating polynomial
+ -- print(A:size())
local res = torch.gels(b,A)
local params = res[{ {1,nPoints*2} }]:squeeze()
@@ -178,7 +179,7 @@ function optim.polyinterp(points,xminBound,xmaxBound)
params[torch.le(torch.abs(params),1e-12)]=0
-- Compute Critical Points
- local dParams = zeros(order);
+ local dParams = points.new(order):zero();
for i = 1,params:size(1)-1 do
dParams[i] = params[i]*(order-i+1)
end
@@ -194,10 +195,12 @@ function optim.polyinterp(points,xminBound,xmaxBound)
-- break
-- end
-- end
- local cp = torch.cat(Tensor{xminBound,xmaxBound},points[{{},1}])
+ -- print(points)
+
+ local cp = torch.cat(points.new{xminBound,xmaxBound},points[{{},1}])
if not nans then
local cproots = roots(dParams)
- local cpi = zeros(cp:size(1),2)
+ local cpi = points.new(cp:size(1),2):zero()
cpi[{ {1,cp:size(1)} , 1 }] = cp
cp = torch.cat(cpi,cproots,1)
end