Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Pfau <pfau@google.com>2015-05-19 16:55:39 +0300
committerDavid Pfau <pfau@google.com>2015-05-19 16:55:39 +0300
commita329ac762524c78c8eec9dc0c1aa1ab480e7c94c (patch)
tree23ce234638fd9368437227c26be7191705a96325
parent446de4f45f3e901adc889f4d6a49051028e437e8 (diff)
Added CUDA support for cg.lua
-rw-r--r--cg.lua17
1 files changed, 11 insertions, 6 deletions
diff --git a/cg.lua b/cg.lua
index a03c220..990fed5 100644
--- a/cg.lua
+++ b/cg.lua
@@ -42,6 +42,11 @@ function optim.cg(opfunc, x, config, state)
local ratio = config.ratio or 100
local maxEval = config.maxEval or maxIter*1.25
local red = 1
+ local cuda = torch.typename(x) == 'torch.CudaTensor'
+ local function newTensor()
+ if cuda then return torch.CudaTensor() end
+ return torch.Tensor()
+ end
local verbose = config.verbose or 0
@@ -54,9 +59,9 @@ function optim.cg(opfunc, x, config, state)
local d1,d2,d3 = 0,0,0
local f1,f2,f3 = 0,0,0
- local df1 = state.df1 or torch.Tensor()
- local df2 = state.df2 or torch.Tensor()
- local df3 = state.df3 or torch.Tensor()
+ local df1 = state.df1 or newTensor()
+ local df2 = state.df2 or newTensor()
+ local df3 = state.df3 or newTensor()
local tdf
df1:resizeAs(x)
@@ -64,13 +69,13 @@ function optim.cg(opfunc, x, config, state)
df3:resizeAs(x)
-- search direction
- local s = state.s or torch.Tensor()
+ local s = state.s or newTensor()
s:resizeAs(x)
-- we need a temp storage for X
- local x0 = state.x0 or torch.Tensor()
+ local x0 = state.x0 or newTensor()
local f0 = 0
- local df0 = state.df0 or torch.Tensor()
+ local df0 = state.df0 or newTensor()
x0:resizeAs(x)
df0:resizeAs(x)