Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Pfau <pfau@google.com>2015-05-19 18:19:55 +0300
committerDavid Pfau <pfau@google.com>2015-05-19 18:19:55 +0300
commit0b78f62128677cee3ff8288481f10a1be8442abd (patch)
treec46fcb1a06b1543808897972c1b5decb80dba07f
parenta329ac762524c78c8eec9dc0c1aa1ab480e7c94c (diff)
Replaced CUDA flag and newTensor() with x.new()
-rw-r--r--cg.lua17
1 files changed, 6 insertions, 11 deletions
diff --git a/cg.lua b/cg.lua
index 990fed5..842a7d5 100644
--- a/cg.lua
+++ b/cg.lua
@@ -42,11 +42,6 @@ function optim.cg(opfunc, x, config, state)
local ratio = config.ratio or 100
local maxEval = config.maxEval or maxIter*1.25
local red = 1
- local cuda = torch.typename(x) == 'torch.CudaTensor'
- local function newTensor()
- if cuda then return torch.CudaTensor() end
- return torch.Tensor()
- end
local verbose = config.verbose or 0
@@ -59,9 +54,9 @@ function optim.cg(opfunc, x, config, state)
local d1,d2,d3 = 0,0,0
local f1,f2,f3 = 0,0,0
- local df1 = state.df1 or newTensor()
- local df2 = state.df2 or newTensor()
- local df3 = state.df3 or newTensor()
+ local df1 = state.df1 or x.new()
+ local df2 = state.df2 or x.new()
+ local df3 = state.df3 or x.new()
local tdf
df1:resizeAs(x)
@@ -69,13 +64,13 @@ function optim.cg(opfunc, x, config, state)
df3:resizeAs(x)
-- search direction
- local s = state.s or newTensor()
+ local s = state.s or x.new()
s:resizeAs(x)
-- we need a temp storage for X
- local x0 = state.x0 or newTensor()
+ local x0 = state.x0 or x.new()
local f0 = 0
- local df0 = state.df0 or newTensor()
+ local df0 = state.df0 or x.new()
x0:resizeAs(x)
df0:resizeAs(x)