diff options
author | Koray Kavukcuoglu <koray@kavukcuoglu.org> | 2012-05-04 00:47:50 +0400 |
---|---|---|
committer | Koray Kavukcuoglu <koray@kavukcuoglu.org> | 2012-05-04 00:47:50 +0400 |
commit | 9f26c9bd2e3619bcf06282c45afcd03ce5fa5942 (patch) | |
tree | 7c53c07c7f1fe8456fc0f52118a6fd1d64858648 | |
parent | d5414bc7ff18e5f379ed41d0468235b86d525795 (diff) |
polyinterp with extrapolation
-rw-r--r-- | polyinterp.lua | 91 |
1 files changed, 73 insertions, 18 deletions
diff --git a/polyinterp.lua b/polyinterp.lua index f8b7e9d..005243c 100644 --- a/polyinterp.lua +++ b/polyinterp.lua @@ -7,11 +7,42 @@ local function isnan(x) end local function roots(c) + local tol=1e-12 + c[torch.lt(torch.abs(c),tol)]=0 + + local nonzero = torch.ne(c,0) + if nonzero:max() == 0 then + return 0 + end + + -- first non-zero + local _,pos = torch.max(nonzero,1) + pos = pos[1] + c=c[{ {pos,-1} }] + + local nz = 0 + for i=c:size(1),1,-1 do + if c[i] ~= 0 then + break + else + nz = nz + 1 + end + end + c=c[{ {1,c:size(1)-nz} }] + local n = c:size(1)-1 - local A = torch.diag(torch.ones(n-1),-1) - A[1] = -c[{ {2,n+1} }]/c[1]; - local e = torch.eig(A,'N') - return e + if n > 0 then + local A = torch.diag(torch.ones(n-1),-1) + A[1] = -c[{ {2,n+1} }]/c[1]; + local e = torch.eig(A,'N') + if nz > 0 then + return torch.cat(e,torch.zeros(nz,2)) + else + return e + end + else + return torch.zeros(nz,2) + end end local function real(x) @@ -29,11 +60,16 @@ local function imag(x) end local function polyval(p,x) - if type(x) == 'number' then x=torch.Tensor{x} end - local val = x.new(x:size(1)) local pwr = p:size(1) - p:apply(function(pc) pwr = pwr-1; val:add(pc,torch.pow(x,pwr)); return pc end) - return val + if type(x) == 'number' then + local val = 0 + p:apply(function(pc) pwr = pwr-1; val = val + pc*x^pwr; return pc end) + return val + else + local val = x.new(x:size(1)) + p:apply(function(pc) pwr = pwr-1; val:add(pc,torch.pow(x,pwr)); return pc end) + return val + end end ---------------------------------------------------------------------- @@ -95,7 +131,7 @@ function optim.polyinterp(points,xminBound,xmaxBound) end -- TODO: get the code below to work! - error('<optim.polyinterp> extrapolation not implemented yet...') + --error('<optim.polyinterp> extrapolation not implemented yet...') -- Compute Bounds of Interpolation Area local xmin = points[{{},1}]:min() @@ -129,6 +165,11 @@ function optim.polyinterp(points,xminBound,xmaxBound) local res = torch.gels(b,A) local params = res[{ {1,nPoints*2} }]:squeeze() + --print(A) + --print(b) + --print(params) + params[torch.le(torch.abs(params),1e-12)]=0 + -- Compute Critical Points local dParams = zeros(order); for i = 1,params:size(1)-1 do @@ -147,26 +188,40 @@ function optim.polyinterp(points,xminBound,xmaxBound) -- end -- end local cp = torch.cat(Tensor{xminBound,xmaxBound},points[{{},1}]) - if not nans + if not nans then local cproots = roots(dParams) local cpi = zeros(cp:size(1),2) cpi[{ {1,cp:size(1)} , 1 }] = cp - cp = torch.cat(cpi,cproots) + cp = torch.cat(cpi,cproots,1) end + --print(dParams) + --print(cp) + -- Test Critical Points local fmin = math.huge -- Default to Bisection if no critical points valid: minPos = (xminBound+xmaxBound)/2 + --print(minPos,fmin) + --print(xminBound,xmaxBound) for i = 1,cp:size(1) do - local xCP = cp[i] - if imag(xCP)==0 and xCP >= xminBound and xCP <= xmaxBound then - local fCP = polyval(params,xCP) - if imag(fCP)==0 and fCP < fmin then - minPos = real(xCP) - fmin = real(fCP) + local xCP = cp[{ {i,i} , {} }] + --print('xcp=') + --print(xCP) + local ixCP = imag(xCP)[1] + local rxCP = real(xCP)[1] + if ixCP == 0 and rxCP >= xminBound and rxCP <= xmaxBound then + local fCP = polyval(params,rxCP) + --print('fcp=') + --print(fCP) + --print(fCP < fmin) + if fCP < fmin then + minPos = rxCP + fmin = fCP + --print('u',minPos,fmin) end + --print('v',minPos,fmin) end end - return minPos + return minPos,fmin end |