Welcome to mirror list, hosted at ThFree Co, Russian Federation.

ZeroGrad.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 446b25d14bf894cc9b291783d329b13e3540fe54 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
-- nn.ZeroGrad: passes its input through unchanged on the forward pass and
-- returns zeroed gradients on the backward pass.
-- The rnn package also defines nn.ZeroGrad; if the class is already
-- registered, reuse the existing class table rather than calling
-- torch.class again (presumably re-registration would conflict — this
-- mirrors the same guard used elsewhere in the torch ecosystem).
local ZeroGrad, parent
if nn.ZeroGrad then -- prevent name conflicts with rnn
   ZeroGrad, parent = nn.ZeroGrad, nn.Module
else
   ZeroGrad, parent = torch.class('nn.ZeroGrad', 'nn.Module')
end

--- Recursively resize t1 to mirror the structure of t2 and fill it with
-- zeros. t2 may be a tensor or an arbitrarily nested table of tensors;
-- t1 is (re)allocated as needed and returned alongside t2.
local function recursiveZero(dst, src)
   local srctype = torch.type(src)
   if srctype == 'table' then
      -- Coerce dst into a table so it can mirror src's nesting.
      if torch.type(dst) ~= 'table' then
         dst = {dst}
      end
      for k in pairs(src) do
         dst[k], src[k] = recursiveZero(dst[k], src[k])
      end
   elseif torch.isTensor(src) then
      -- Lazily allocate a tensor of src's type, then match its shape.
      if not dst then
         dst = src.new()
      end
      dst:resizeAs(src)
      dst:zero()
   else
      error("expecting nested tensors or tables. Got "..
            torch.type(dst).." and "..torch.type(src).." instead")
   end
   return dst, src
end

--- Forward pass: identity. The output tensor is made to share the
-- input's storage via :set(), so no data is copied.
function ZeroGrad:updateOutput(input)
   local output = self.output
   output:set(input)
   return output
end

-- the gradient is simply zeroed.
-- useful when you don't want to backpropagate through certain paths.
function ZeroGrad:updateGradInput(input, gradOutput)
   -- Zero-fill self.gradInput with the same (possibly nested) structure
   -- as gradOutput.
   self.gradInput = recursiveZero(self.gradInput, gradOutput)
   -- Fix: the nn.Module contract requires updateGradInput to return
   -- self.gradInput; the original fell through returning nothing, so
   -- callers using the return value received nil.
   return self.gradInput
end