Welcome to mirror list, hosted at ThFree Co, Russian Federation.

Probe.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: ea8527cb09ad9c2446ab668aee963b3da21dd7e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
--- nn.Probe: a pass-through module for inspecting data inside a network.
-- Forward returns its input unchanged; backward returns gradOutput unchanged.
-- Depending on the constructor options, the data is also printed (full
-- tensor or just its size) and/or shown with image.display.
local Probe, parent = torch.class('nn.Probe', 'nn.Module')

--- Construct a probe.
-- Options, as declared to xlua.unpack_class (which stores each one as a
-- field on self):
--   name    (string, required)  label shown in every printed header
--   print   (boolean, false)    print the full tensor
--   display (boolean, false)    display the tensor via image.display
--   size    (boolean, false)    print the tensor size
--   backw   (boolean, false)    also report during updateGradInput()
function Probe:__init(...)
   parent.__init(self)
   xlua.unpack_class(self, {...}, 'nn.Probe', 
                     'print/display input/gradients of a network',
                     {arg='name', type='string', help='unique name to identify probe', req=true},
                     {arg='print', type='boolean', help='print full tensor', default=false},
                     {arg='display', type='boolean', help='display tensor', default=false},
                     {arg='size', type='boolean', help='print tensor size', default=false},
                     {arg='backw', type='boolean', help='activates probe for backward()', default=false})
end

--- Print and/or display `tensor` according to the probe's options.
-- @param probe   the Probe instance whose option fields are consulted
-- @param method  name of the calling method, used in the printed header
-- @param tensor  the data flowing through the probe
-- @param win     window handle from a previous image.display call (may be nil)
-- @return the window handle to cache: the one image.display returned when
--         display is enabled, otherwise `win` unchanged
local function report(probe, method, tensor, win)
   -- The constructor option is named `print` (stored as probe.print).
   -- Previously only `content` was checked, which was never set, so full
   -- printing could not be enabled. `content` is still honoured in case
   -- callers assign that field directly.
   local show_content = probe.print or probe.content
   if probe.size or show_content then
      print('')
      print('<probe::' .. probe.name .. '> ' .. method .. '()')
      if show_content then
         print(tensor)
      else
         -- `#` on a torch Tensor yields its size
         print(#tensor)
      end
   end
   if probe.display then
      -- reuse the previous window so repeated calls update it in place
      win = image.display{image=tensor, win=win}
   end
   return win
end

--- Forward pass: identity, reporting the input.
-- @param input  data passed to the module
-- @return input, unchanged (also stored in self.output)
function Probe:updateOutput(input)
   self.output = input
   self.winf = report(self, 'updateOutput', input, self.winf)
   return self.output
end

--- Backward pass: identity on the gradient, reported only when backw is set.
-- @param input       the forward input (unused)
-- @param gradOutput  gradient w.r.t. this module's output
-- @return gradOutput, unchanged (also stored in self.gradInput)
function Probe:updateGradInput(input, gradOutput)
   self.gradInput = gradOutput
   if self.backw then
      self.winb = report(self, 'updateGradInput', gradOutput, self.winb)
   end
   return self.gradInput
end