diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-05-24 23:41:54 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-05-24 23:41:54 +0300 |
commit | cf81236bfe2ffa1ba1c85e24deffb0cef8ecbccc (patch) | |
tree | 925845cc5d7dc5c6dfa00096c454e1999b5c9bc1 | |
parent | 2241454174357681f40f0ab236922668a2bd36fb (diff) |
moved ZeroGrad to nnx
-rw-r--r-- | ZeroGrad.lua | 33 | ||||
-rw-r--r-- | init.lua | 1 |
2 files changed, 0 insertions, 34 deletions
diff --git a/ZeroGrad.lua b/ZeroGrad.lua deleted file mode 100644 index 446b25d..0000000 --- a/ZeroGrad.lua +++ /dev/null @@ -1,33 +0,0 @@ -local ZeroGrad, parent -if nn.ZeroGrad then -- prevent name conflicts with rnn - ZeroGrad, parent = nn.ZeroGrad, nn.Module -else - ZeroGrad, parent = torch.class('nn.ZeroGrad', 'nn.Module') -end - -local function recursiveZero(t1,t2) - if torch.type(t2) == 'table' then - t1 = (torch.type(t1) == 'table') and t1 or {t1} - for key,_ in pairs(t2) do - t1[key], t2[key] = recursiveZero(t1[key], t2[key]) - end - elseif torch.isTensor(t2) then - t1 = t1 or t2.new() - t1:resizeAs(t2):zero() - else - error("expecting nested tensors or tables. Got ".. - torch.type(t1).." and "..torch.type(t2).." instead") - end - return t1, t2 -end - -function ZeroGrad:updateOutput(input) - self.output:set(input) - return self.output -end - --- the gradient is simply zeroed. --- useful when you don't want to backpropgate through certain paths. -function ZeroGrad:updateGradInput(input, gradOutput) - self.gradInput = recursiveZero(self.gradInput, gradOutput) -end @@ -75,7 +75,6 @@ require('nnx.MultiSoftMax') require('nnx.Balance') require('nnx.PushTable') require('nnx.PullTable') -require('nnx.ZeroGrad') require('nnx.QDRiemaNNLinear') -- criterions: |