Welcome to mirror list, hosted at ThFree Co, Russian Federation.

Probe.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: af6f2923b5856c1bcd512c636c411c8da2c36352 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
-- Declare the class 'nn.Probe' with 'nn.Module' named as its parent class.
-- Probe is the table the methods below are attached to; parent is used by
-- the constructor to run the parent class initialisation.
local Probe, parent = torch.class('nn.Probe', 'nn.Module')

-- Constructor.
--
-- name:    label used in the printed reports; falls back to 'unnamed'.
-- display: when truthy, the probed values are also sent to image.display
--          by the update methods.
function Probe:__init(name,display)
   parent.__init(self)
   -- A single timer is kept on the nn table and created only if none is
   -- there yet, so every probe measures against the same clock.
   if not nn._ProbeTimer then
      nn._ProbeTimer = torch.Timer()
   end
   self.display = display
   self.name = name or 'unnamed'
end

-- Forward pass: hands the input through untouched and prints a short report
-- about it (size, mean, std, min, max, time elapsed since the last probe).
--
-- input:   the value flowing through the layer. The code calls :dim(),
--          :size(i), :mean(), :std(), :min() and :max() on it, so it must
--          provide those methods (presumably a torch Tensor -- confirm
--          against callers).
-- returns: self.output, which is the same object as input (no copy is made).
function Probe:updateOutput(input)
   self.output = input
   local legend = '<' .. self.name .. '>.output'
   local size = {}
   for i = 1,input:dim() do
      size[i] = input:size(i)
   end
   size = table.concat(size,'x')
   -- Read the timer exactly once: the same reading serves as the end of this
   -- interval and the start of the next, so no time is lost between probes.
   -- Timer and last timestamp are fields of nn, not of this instance.
   local now = nn._ProbeTimer:time().real
   local diff = now - (nn._ProbeLast or 0)
   nn._ProbeLast = now
   print('')
   print(legend)
   print('  + size = ' .. size)
   print('  + mean = ' .. input:mean())
   print('  + std = ' .. input:std())
   print('  + min = ' .. input:min())
   print('  + max = ' .. input:max())
   print('  + time since last probe = ' .. string.format('%0.1f',diff*1000) .. 'ms')
   if self.display then
      -- Load the image package on first use instead of failing with
      -- "attempt to index global 'image'" when the caller never required it.
      if not image then require 'image' end
      -- The window handle is kept so later calls reuse the same window.
      self.winf = image.display{image=input, win=self.winf, legend=legend}
   end
   return self.output
end

-- Backward pass: hands gradOutput through untouched and prints a short report
-- about it (size, mean, std, min, max, time elapsed since the last probe).
--
-- input:      the forward input; not used by this method.
-- gradOutput: the gradient flowing back through the layer. The code calls
--             :dim(), :size(i), :mean(), :std(), :min() and :max() on it, so
--             it must provide those methods (presumably a torch Tensor --
--             confirm against callers).
-- returns:    self.gradInput, which is the same object as gradOutput (no
--             copy is made).
function Probe:updateGradInput(input, gradOutput)
   self.gradInput = gradOutput
   local legend = 'layer<' .. self.name .. '>.gradInput'
   local size = {}
   for i = 1,gradOutput:dim() do
      size[i] = gradOutput:size(i)
   end
   size = table.concat(size,'x')
   -- Read the timer exactly once: the same reading serves as the end of this
   -- interval and the start of the next, so no time is lost between probes.
   -- Timer and last timestamp are fields of nn, not of this instance.
   local now = nn._ProbeTimer:time().real
   local diff = now - (nn._ProbeLast or 0)
   nn._ProbeLast = now
   print('')
   print(legend)
   print('  + size = ' .. size)
   print('  + mean = ' .. gradOutput:mean())
   print('  + std = ' .. gradOutput:std())
   print('  + min = ' .. gradOutput:min())
   print('  + max = ' .. gradOutput:max())
   print('  + time since last probe = ' .. string.format('%0.1f',diff*1000) .. 'ms')
   if self.display then
      -- Load the image package on first use instead of failing with
      -- "attempt to index global 'image'" when the caller never required it.
      if not image then require 'image' end
      -- The window handle is kept so later calls reuse the same window.
      self.winb = image.display{image=gradOutput, win=self.winb, legend=legend}
   end
   return self.gradInput
end