From 890f7fa013ba9e59b74578c6d5b916014aac7efd Mon Sep 17 00:00:00 2001 From: Dmitry Ulyanov Date: Sun, 9 Oct 2016 18:46:24 +0300 Subject: fix polyinterp, so lswolfe can be used with CUDA --- polyinterp.lua | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/polyinterp.lua b/polyinterp.lua index 5975981..5c7e49b 100644 --- a/polyinterp.lua +++ b/polyinterp.lua @@ -32,23 +32,23 @@ local function roots(c) local n = c:size(1)-1 if n == 1 then - local e = torch.Tensor({{-c[2]/c[1], 0}}) + local e = c.new({{-c[2]/c[1], 0}}) if nz > 0 then - return torch.cat(e, torch.zeros(nz, 2), 1) + return torch.cat(e, c.new(nz, 2):zero(), 1) else return e end elseif n > 1 then - local A = torch.diag(torch.ones(n-1),-1) + local A = torch.diag(c.new(n-1):fill(1),-1) A[1] = -c[{ {2,n+1} }]/c[1]; local e = torch.eig(A,'N') if nz > 0 then - return torch.cat(e, torch.zeros(nz,2), 1) + return torch.cat(e, c.new(nz,2):zero(), 1) else return e end else - return torch.zeros(nz,2) + return c.new(nz,2):zero() end end @@ -60,7 +60,7 @@ end local function imag(x) if type(x) == 'number' then return 0 end if x:nDimension() == 1 then - return torch.zeros(x:size(1)) + return x.new(x:size(1)):zero() else return x[{ {}, 2}] end @@ -95,8 +95,8 @@ function optim.polyinterp(points,xminBound,xmaxBound) -- locals local sqrt = torch.sqrt local mean = torch.mean - local Tensor = torch.Tensor - local zeros = torch.zeros + -- local Tensor = torch.Tensor + -- local zeros = torch.zeros local max = math.max local min = math.min @@ -147,10 +147,10 @@ function optim.polyinterp(points,xminBound,xmaxBound) xmaxBound = xmaxBound or xmax -- Add constraints on function values - local A = zeros(nPoints*2,order+1) - local b = zeros(nPoints*2,1) + local A = points.new(nPoints*2,order+1):zero() + local b = points.new(nPoints*2,1):zero() for i = 1,nPoints do - local constraint = zeros(order+1) + local constraint = points.new(order+1):zero() for j = order,0,-1 do constraint[order-j+1] = points[{i,1}]^j end @@ -160,7 +160,7 @@ function optim.polyinterp(points,xminBound,xmaxBound) -- Add constraints based on derivatives for i = 1,nPoints do - local constraint = zeros(order+1) + local constraint = points.new(order+1):zero() for j = 1,order do constraint[j] = (order-j+1)*points[{i,1}]^(order-j) end @@ -169,6 +169,7 @@ function optim.polyinterp(points,xminBound,xmaxBound) end -- Find interpolating polynomial + -- print(A:size()) local res = torch.gels(b,A) local params = res[{ {1,nPoints*2} }]:squeeze() @@ -178,7 +179,7 @@ function optim.polyinterp(points,xminBound,xmaxBound) params[torch.le(torch.abs(params),1e-12)]=0 -- Compute Critical Points - local dParams = zeros(order); + local dParams = points.new(order):zero(); for i = 1,params:size(1)-1 do dParams[i] = params[i]*(order-i+1) end @@ -194,10 +195,12 @@ function optim.polyinterp(points,xminBound,xmaxBound) -- break -- end -- end - local cp = torch.cat(Tensor{xminBound,xmaxBound},points[{{},1}]) + -- print(points) + + local cp = torch.cat(points.new{xminBound,xmaxBound},points[{{},1}]) if not nans then local cproots = roots(dParams) - local cpi = zeros(cp:size(1),2) + local cpi = points.new(cp:size(1),2):zero() cpi[{ {1,cp:size(1)} , 1 }] = cp cp = torch.cat(cpi,cproots,1) end -- cgit v1.2.3