Welcome to mirror list, hosted at ThFree Co, Russian Federation.

Profile.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 36cd909cdcbc72fc7ecb8813036236248488a533 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
-- nn.Profile: a decorator around another module that accumulates the time
-- spent in its forward / backward calls and periodically prints a report.
local ProfileModule, parent = torch.class("nn.Profile", "nn.Decorator")

--- Build a profiling wrapper around `module`.
-- @param module          the module whose calls are timed
-- @param print_interval  number of calls between two printed reports
--                        (falls back to 100 when omitted)
-- @param name            label shown in the reports
--                        (falls back to torch.type(module) when omitted)
function ProfileModule:__init(module, print_interval, name)
   parent.__init(self, module)

   -- Resolve the optional arguments up front.
   if not print_interval then
      print_interval = 100
   end
   if not name then
      name = torch.type(module)
   end

   self.print_interval = print_interval
   self.name = name
   self.module = module

   -- Call counters and accumulated wall-clock time since the last report,
   -- tracked separately for the forward and the backward direction.
   self.numFwds, self.numBwds = 0, 0
   self.summedFwdTime, self.summedBwdTime = 0, 0

   -- One timer is shared by both directions; it is reset before each use.
   self.timer = torch.Timer()
end

--- Run the wrapped module's forward pass while timing it.
-- The elapsed time is added to a running total; once the number of calls
-- since the last report reaches a multiple of `print_interval`, the total
-- is printed and both the total and the counter start over from zero.
-- @param input  forwarded unchanged to the wrapped module
-- @return the wrapped module's output (also stored in self.output)
function ProfileModule:updateOutput(input)
   self.timer:reset()
   self.output = self.module:updateOutput(input)
   local elapsed = self.timer:time().real

   -- Book-keeping happens before the interval check, as the check relies
   -- on the freshly incremented counter.
   self.summedFwdTime = self.summedFwdTime + elapsed
   local calls = self.numFwds + 1
   self.numFwds = calls

   if calls % self.print_interval ~= 0 then
      -- Not at a reporting boundary yet.
      return self.output
   end

   local report = string.format('%s took %.3f seconds for %d forward passes',
      self.name, self.summedFwdTime, self.print_interval)
   print(report)

   -- Start a new measurement window.
   self.numFwds = 0
   self.summedFwdTime = 0
   return self.output
end

--- Run the wrapped module's updateGradInput while timing it.
-- Only the wrapped updateGradInput call is measured here. The elapsed time
-- is added to a running total; once the number of calls since the last
-- report reaches a multiple of `print_interval`, the total is printed and
-- both the total and the counter start over from zero.
-- @param input       forwarded unchanged to the wrapped module
-- @param gradOutput  forwarded unchanged to the wrapped module
-- @return the wrapped module's gradInput (also stored in self.gradInput)
function ProfileModule:updateGradInput(input, gradOutput)
   self.timer:reset()
   self.gradInput = self.module:updateGradInput(input, gradOutput)
   local elapsed = self.timer:time().real

   -- Book-keeping happens before the interval check, as the check relies
   -- on the freshly incremented counter.
   self.summedBwdTime = self.summedBwdTime + elapsed
   local calls = self.numBwds + 1
   self.numBwds = calls

   if calls % self.print_interval ~= 0 then
      -- Not at a reporting boundary yet.
      return self.gradInput
   end

   local report = string.format('%s took %.3f seconds for %d backward passes',
      self.name, self.summedBwdTime, self.print_interval)
   print(report)

   -- Start a new measurement window.
   self.numBwds = 0
   self.summedBwdTime = 0
   return self.gradInput
end

-- Install placeholder serialization hooks on torch.Timer's metatable.
-- nn.Profile keeps a Timer as a field, so the Timer has to be serializable
-- for the layer to be saved, cloned, etc. Nothing is written for a Timer,
-- and reading hands back the instance produced by the factory as-is.
local function makeTorchTimerSerializable()
   -- Writer hook: deliberately stores nothing.
   local function writeNothing(object, file)
   end

   -- Reader hook: consumes nothing and returns the object untouched.
   local function readNothing(object, file, versionNumber)
      return object
   end

   local mt = getmetatable(torch.Timer())
   mt.__factory = torch.Timer
   mt.write = writeNothing
   mt.read = readNothing
end

-- Applied once, when this file is loaded.
makeTorchTimerSerializable()