
github.com/torch/optim.git
author     Soumith Chintala <soumith@gmail.com>   2016-08-25 02:04:29 +0300
committer  GitHub <noreply@github.com>            2016-08-25 02:04:29 +0300
commit     afd54e6fc560669f5f541f687e00a9ddeab075cc
tree       488ac80d1a7bab87b273026832bd35faca8a5599
parent     9c7fc422356a5a50adb15961aa27929371c2ea01
parent     84501e325f1cf7adf64b10a0b2a8aeb3a0928b32
Merge pull request #127 from gcinbis/patch-2
Copy C1 value, in case it is a Tensor reference
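
The problem this addresses: when opfunc returns its objective value as a Tensor (for example a 1-element loss tensor it reuses on every call), the C1 returned by the first finite-difference evaluation is only a reference to that buffer, so the second evaluation overwrites it and C1 - C2 collapses to zero. The patch copies C1 into a scratch tensor (Ctmp) before the second call. Below is a minimal Lua sketch of the failure mode, using a hypothetical opfunc that is not part of this repository:

-- Hypothetical closure that reuses a single loss tensor across calls.
require 'torch'

local loss = torch.Tensor(1)
local function opfunc(x)
   loss[1] = 0.5 * x:dot(x)      -- f(x) = 0.5 * ||x||^2
   return loss, x:clone()        -- loss is returned as a Tensor reference, grad = x
end

local x = torch.randn(3)
local C1 = opfunc(x)             -- C1 aliases `loss`
x[1] = x[1] + 1e-4
local C2 = opfunc(x)             -- second call overwrites `loss`, and therefore C1
print(C1[1] - C2[1])             -- prints 0: the finite difference vanishes
-- Copying C1 first (as checkgrad now does via Ctmp) preserves the original value.
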
-rw-r--r--  checkgrad.lua | 12
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/checkgrad.lua b/checkgrad.lua
index 402d9fc..908a1a2 100644
--- a/checkgrad.lua
+++ b/checkgrad.lua
@@ -20,9 +20,15 @@ RETURN:
 function optim.checkgrad(opfunc, x, eps)
     -- compute true gradient:
-    local _,dC = opfunc(x)
+    local Corg,dC = opfunc(x)
     dC:resize(x:size())
+    local Ctmp -- temporary value
+    local isTensor = torch.isTensor(Corg)
+    if isTensor then
+       Ctmp = Corg.new(Corg:size())
+    end
+
     -- compute numeric approximations to gradient:
     local eps = eps or 1e-7
     local dC_est = torch.Tensor():typeAs(dC):resizeAs(dC)
@@ -30,6 +36,10 @@ function optim.checkgrad(opfunc, x, eps)
       local tmp = x[i]
       x[i] = x[i] + eps
       local C1 = opfunc(x)
+      if isTensor then
+         Ctmp:copy(C1)
+         C1 = Ctmp
+      end
       x[i] = x[i] - 2 * eps
       local C2 = opfunc(x)
       x[i] = tmp
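
For completeness, a minimal usage sketch of the patched checker (the quadratic opfunc below is illustrative, not from the repository): checkgrad perturbs each coordinate by ±eps and compares the central difference (C1 - C2) / (2 * eps) against the analytic gradient dC.

require 'torch'
require 'optim'

-- Illustrative objective that returns its loss as a 1-element Tensor,
-- the case this patch makes optim.checkgrad handle correctly.
local loss = torch.Tensor(1)
local function opfunc(x)
   loss[1] = 0.5 * x:dot(x)
   return loss, x:clone()        -- f(x) and df/dx
end

local x = torch.randn(5)
local diff, dC, dC_est = optim.checkgrad(opfunc, x, 1e-7)
print(diff)                      -- relative error between dC and dC_est; should be near zero
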