Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKoray Kavukcuoglu <koray@kavukcuoglu.org>2012-05-04 00:47:50 +0400
committerKoray Kavukcuoglu <koray@kavukcuoglu.org>2012-05-04 00:47:50 +0400
commit9f26c9bd2e3619bcf06282c45afcd03ce5fa5942 (patch)
tree7c53c07c7f1fe8456fc0f52118a6fd1d64858648
parentd5414bc7ff18e5f379ed41d0468235b86d525795 (diff)
polyinterp with extrapolation
-rw-r--r--polyinterp.lua91
1 files changed, 73 insertions, 18 deletions
diff --git a/polyinterp.lua b/polyinterp.lua
index f8b7e9d..005243c 100644
--- a/polyinterp.lua
+++ b/polyinterp.lua
@@ -7,11 +7,42 @@ local function isnan(x)
end
local function roots(c)
+ local tol=1e-12
+ c[torch.lt(torch.abs(c),tol)]=0
+
+ local nonzero = torch.ne(c,0)
+ if nonzero:max() == 0 then
+ return 0
+ end
+
+ -- first non-zero
+ local _,pos = torch.max(nonzero,1)
+ pos = pos[1]
+ c=c[{ {pos,-1} }]
+
+ local nz = 0
+ for i=c:size(1),1,-1 do
+ if c[i] ~= 0 then
+ break
+ else
+ nz = nz + 1
+ end
+ end
+ c=c[{ {1,c:size(1)-nz} }]
+
local n = c:size(1)-1
- local A = torch.diag(torch.ones(n-1),-1)
- A[1] = -c[{ {2,n+1} }]/c[1];
- local e = torch.eig(A,'N')
- return e
+ if n > 0 then
+ local A = torch.diag(torch.ones(n-1),-1)
+ A[1] = -c[{ {2,n+1} }]/c[1];
+ local e = torch.eig(A,'N')
+ if nz > 0 then
+ return torch.cat(e,torch.zeros(nz,2))
+ else
+ return e
+ end
+ else
+ return torch.zeros(nz,2)
+ end
end
local function real(x)
@@ -29,11 +60,16 @@ local function imag(x)
end
local function polyval(p,x)
- if type(x) == 'number' then x=torch.Tensor{x} end
- local val = x.new(x:size(1))
local pwr = p:size(1)
- p:apply(function(pc) pwr = pwr-1; val:add(pc,torch.pow(x,pwr)); return pc end)
- return val
+ if type(x) == 'number' then
+ local val = 0
+ p:apply(function(pc) pwr = pwr-1; val = val + pc*x^pwr; return pc end)
+ return val
+ else
+ local val = x.new(x:size(1))
+ p:apply(function(pc) pwr = pwr-1; val:add(pc,torch.pow(x,pwr)); return pc end)
+ return val
+ end
end
----------------------------------------------------------------------
@@ -95,7 +131,7 @@ function optim.polyinterp(points,xminBound,xmaxBound)
end
-- TODO: get the code below to work!
- error('<optim.polyinterp> extrapolation not implemented yet...')
+ --error('<optim.polyinterp> extrapolation not implemented yet...')
-- Compute Bounds of Interpolation Area
local xmin = points[{{},1}]:min()
@@ -129,6 +165,11 @@ function optim.polyinterp(points,xminBound,xmaxBound)
local res = torch.gels(b,A)
local params = res[{ {1,nPoints*2} }]:squeeze()
+ --print(A)
+ --print(b)
+ --print(params)
+ params[torch.le(torch.abs(params),1e-12)]=0
+
-- Compute Critical Points
local dParams = zeros(order);
for i = 1,params:size(1)-1 do
@@ -147,26 +188,40 @@ function optim.polyinterp(points,xminBound,xmaxBound)
-- end
-- end
local cp = torch.cat(Tensor{xminBound,xmaxBound},points[{{},1}])
- if not nans
+ if not nans then
local cproots = roots(dParams)
local cpi = zeros(cp:size(1),2)
cpi[{ {1,cp:size(1)} , 1 }] = cp
- cp = torch.cat(cpi,cproots)
+ cp = torch.cat(cpi,cproots,1)
end
+ --print(dParams)
+ --print(cp)
+
-- Test Critical Points
local fmin = math.huge
-- Default to Bisection if no critical points valid:
minPos = (xminBound+xmaxBound)/2
+ --print(minPos,fmin)
+ --print(xminBound,xmaxBound)
for i = 1,cp:size(1) do
- local xCP = cp[i]
- if imag(xCP)==0 and xCP >= xminBound and xCP <= xmaxBound then
- local fCP = polyval(params,xCP)
- if imag(fCP)==0 and fCP < fmin then
- minPos = real(xCP)
- fmin = real(fCP)
+ local xCP = cp[{ {i,i} , {} }]
+ --print('xcp=')
+ --print(xCP)
+ local ixCP = imag(xCP)[1]
+ local rxCP = real(xCP)[1]
+ if ixCP == 0 and rxCP >= xminBound and rxCP <= xmaxBound then
+ local fCP = polyval(params,rxCP)
+ --print('fcp=')
+ --print(fCP)
+ --print(fCP < fmin)
+ if fCP < fmin then
+ minPos = rxCP
+ fmin = fCP
+ --print('u',minPos,fmin)
end
+ --print('v',minPos,fmin)
end
end
- return minPos
+ return minPos,fmin
end