diff options
author | David Pfau <pfau@google.com> | 2015-05-19 16:55:39 +0300 |
---|---|---|
committer | David Pfau <pfau@google.com> | 2015-05-19 16:55:39 +0300 |
commit | a329ac762524c78c8eec9dc0c1aa1ab480e7c94c (patch) | |
tree | 23ce234638fd9368437227c26be7191705a96325 | |
parent | 446de4f45f3e901adc889f4d6a49051028e437e8 (diff) |
Added CUDA support for cg.lua
-rw-r--r-- | cg.lua | 17 |
1 files changed, 11 insertions, 6 deletions
@@ -42,6 +42,11 @@ function optim.cg(opfunc, x, config, state) local ratio = config.ratio or 100 local maxEval = config.maxEval or maxIter*1.25 local red = 1 + local cuda = torch.typename(x) == 'torch.CudaTensor' + local function newTensor() + if cuda then return torch.CudaTensor() end + return torch.Tensor() + end local verbose = config.verbose or 0 @@ -54,9 +59,9 @@ function optim.cg(opfunc, x, config, state) local d1,d2,d3 = 0,0,0 local f1,f2,f3 = 0,0,0 - local df1 = state.df1 or torch.Tensor() - local df2 = state.df2 or torch.Tensor() - local df3 = state.df3 or torch.Tensor() + local df1 = state.df1 or newTensor() + local df2 = state.df2 or newTensor() + local df3 = state.df3 or newTensor() local tdf df1:resizeAs(x) @@ -64,13 +69,13 @@ function optim.cg(opfunc, x, config, state) df3:resizeAs(x) -- search direction - local s = state.s or torch.Tensor() + local s = state.s or newTensor() s:resizeAs(x) -- we need a temp storage for X - local x0 = state.x0 or torch.Tensor() + local x0 = state.x0 or newTensor() local f0 = 0 - local df0 = state.df0 or torch.Tensor() + local df0 = state.df0 or newTensor() x0:resizeAs(x) df0:resizeAs(x) |