diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-10-09 19:41:10 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-10-09 19:41:10 +0300 |
commit | 9f6367cff15592db3321a2913f47dacb2abc3c3e (patch) | |
tree | 84c7bc8df5486b67ba3235cf8597e0f2783ba456 | |
parent | 2fe04c2e6a7299faec1f74f7c87aa4efbde38365 (diff) | |
parent | d398cd38f0bab6de0320cf71acb3322da84e6382 (diff) |
Merge pull request #138 from DmitryUlyanov/master
Fix polyinterp to let lbfgs wtih lswolfe work on GPU
-rw-r--r-- | polyinterp.lua | 50 |
1 files changed, 14 insertions, 36 deletions
diff --git a/polyinterp.lua b/polyinterp.lua index 5975981..35317ac 100644 --- a/polyinterp.lua +++ b/polyinterp.lua @@ -32,23 +32,23 @@ local function roots(c) local n = c:size(1)-1 if n == 1 then - local e = torch.Tensor({{-c[2]/c[1], 0}}) + local e = c.new({{-c[2]/c[1], 0}}) if nz > 0 then - return torch.cat(e, torch.zeros(nz, 2), 1) + return torch.cat(e, c.new(nz, 2):zero(), 1) else return e end elseif n > 1 then - local A = torch.diag(torch.ones(n-1),-1) + local A = torch.diag(c.new(n-1):fill(1),-1) A[1] = -c[{ {2,n+1} }]/c[1]; local e = torch.eig(A,'N') if nz > 0 then - return torch.cat(e, torch.zeros(nz,2), 1) + return torch.cat(e, c.new(nz,2):zero(), 1) else return e end else - return torch.zeros(nz,2) + return c.new(nz,2):zero() end end @@ -60,7 +60,7 @@ end local function imag(x) if type(x) == 'number' then return 0 end if x:nDimension() == 1 then - return torch.zeros(x:size(1)) + return x.new(x:size(1)):zero() else return x[{ {}, 2}] end @@ -95,8 +95,6 @@ function optim.polyinterp(points,xminBound,xmaxBound) -- locals local sqrt = torch.sqrt local mean = torch.mean - local Tensor = torch.Tensor - local zeros = torch.zeros local max = math.max local min = math.min @@ -147,10 +145,10 @@ function optim.polyinterp(points,xminBound,xmaxBound) xmaxBound = xmaxBound or xmax -- Add constraints on function values - local A = zeros(nPoints*2,order+1) - local b = zeros(nPoints*2,1) + local A = points.new(nPoints*2,order+1):zero() + local b = points.new(nPoints*2,1):zero() for i = 1,nPoints do - local constraint = zeros(order+1) + local constraint = points.new(order+1):zero() for j = order,0,-1 do constraint[order-j+1] = points[{i,1}]^j end @@ -160,7 +158,7 @@ function optim.polyinterp(points,xminBound,xmaxBound) -- Add constraints based on derivatives for i = 1,nPoints do - local constraint = zeros(order+1) + local constraint = points.new(order+1):zero() for j = 1,order do constraint[j] = (order-j+1)*points[{i,1}]^(order-j) end @@ -172,13 +170,10 @@ function optim.polyinterp(points,xminBound,xmaxBound) local res = torch.gels(b,A) local params = res[{ {1,nPoints*2} }]:squeeze() - --print(A) - --print(b) - --print(params) params[torch.le(torch.abs(params),1e-12)]=0 -- Compute Critical Points - local dParams = zeros(order); + local dParams = points.new(order):zero(); for i = 1,params:size(1)-1 do dParams[i] = params[i]*(order-i+1) end @@ -188,46 +183,29 @@ function optim.polyinterp(points,xminBound,xmaxBound) if torch.ne(dParams,dParams):max() > 0 or torch.eq(dParams,math.huge):max() > 0 then nans = true end - -- for i = 1,dParams:size(1) do - -- if dParams[i] ~= dParams[i] or dParams[i] == math.huge then - -- nans = true - -- break - -- end - -- end - local cp = torch.cat(Tensor{xminBound,xmaxBound},points[{{},1}]) + + local cp = torch.cat(points.new{xminBound,xmaxBound},points[{{},1}]) if not nans then local cproots = roots(dParams) - local cpi = zeros(cp:size(1),2) + local cpi = points.new(cp:size(1),2):zero() cpi[{ {1,cp:size(1)} , 1 }] = cp cp = torch.cat(cpi,cproots,1) end - --print(dParams) - --print(cp) - -- Test Critical Points local fmin = math.huge -- Default to Bisection if no critical points valid: minPos = (xminBound+xmaxBound)/2 - --print(minPos,fmin) - --print(xminBound,xmaxBound) for i = 1,cp:size(1) do local xCP = cp[{ {i,i} , {} }] - --print('xcp=') - --print(xCP) local ixCP = imag(xCP)[1] local rxCP = real(xCP)[1] if ixCP == 0 and rxCP >= xminBound and rxCP <= xmaxBound then local fCP = polyval(params,rxCP) - --print('fcp=') - --print(fCP) - --print(fCP < fmin) if fCP < fmin then minPos = rxCP fmin = fCP - --print('u',minPos,fmin) end - --print('v',minPos,fmin) end end return minPos,fmin |