diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-08-25 02:04:29 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-25 02:04:29 +0300 |
commit | afd54e6fc560669f5f541f687e00a9ddeab075cc (patch) | |
tree | 488ac80d1a7bab87b273026832bd35faca8a5599 | |
parent | 9c7fc422356a5a50adb15961aa27929371c2ea01 (diff) | |
parent | 84501e325f1cf7adf64b10a0b2a8aeb3a0928b32 (diff) |
Merge pull request #127 from gcinbis/patch-2
Copy C1 value, in case it is a Tensor reference
-rw-r--r-- | checkgrad.lua | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/checkgrad.lua b/checkgrad.lua index 402d9fc..908a1a2 100644 --- a/checkgrad.lua +++ b/checkgrad.lua @@ -20,9 +20,15 @@ RETURN: function optim.checkgrad(opfunc, x, eps) -- compute true gradient: - local _,dC = opfunc(x) + local Corg,dC = opfunc(x) dC:resize(x:size()) + local Ctmp -- temporary value + local isTensor = torch.isTensor(Corg) + if isTensor then + Ctmp = Corg.new(Corg:size()) + end + -- compute numeric approximations to gradient: local eps = eps or 1e-7 local dC_est = torch.Tensor():typeAs(dC):resizeAs(dC) @@ -30,6 +36,10 @@ function optim.checkgrad(opfunc, x, eps) local tmp = x[i] x[i] = x[i] + eps local C1 = opfunc(x) + if isTensor then + Ctmp:copy(C1) + C1 = Ctmp + end x[i] = x[i] - 2 * eps local C2 = opfunc(x) x[i] = tmp |