diff options
-rw-r--r-- | Probe.lua | 62 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | nnx-1.0-1.rockspec | 1 |
3 files changed, 64 insertions, 0 deletions
diff --git a/Probe.lua b/Probe.lua new file mode 100644 index 0000000..3c93cd3 --- /dev/null +++ b/Probe.lua @@ -0,0 +1,62 @@ +local Probe, parent = torch.class('nn.Probe', 'nn.Module') + +function Probe:__init(...) + parent.__init(self) + xlua.unpack_class(self, {...}, 'nn.Probe', + 'print/display input/gradients of a network', + {arg='name', type='string', help='unique name to identify probe', req=true}, + {arg='print', type='boolean', help='print full tensor', default=false}, + {arg='display', type='boolean', help='display tensor', default=false}, + {arg='size', type='boolean', help='print tensor size', default=false}, + {arg='backw', type='boolean', help='activates probe for backward()', default=false}) +end + +function Probe:forward(input) + self.output = input + if self.size or self.content then + print('') + print('<probe::' .. self.name .. '> forward()') + if self.content then print(input) + elseif self.size then print(#input) + end + end + if self.display then + self.winf = image.display{image=input, win=self.winf} + end + return self.output +end + +function Probe:backward(input, gradOutput) + self.gradInput = gradOutput + if self.backw then + if self.size or self.content then + print('') + print('<probe::' .. self.name .. '> backward()') + if self.content then print(gradOutput) + elseif self.size then print(#gradOutput) + end + end + if self.display then + self.winb = image.display{image=gradOutput, win=self.winb} + end + end + return self.gradInput +end + +function Probe:write(file) + parent.write(self, file) + file:writeObject(self.name) + file:writeBool(self.content) + file:writeBool(self.display) + file:writeBool(self.size) + file:writeBool(self.backw) +end + +function Probe:read(file) + parent.read(self, file) + self.name = file:readObject() + self.content = file:readBool() + self.display = file:readBool() + self.size = file:readBool() + self.backw = file:readBool() +end @@ -47,6 +47,7 @@ torch.include('nnx', 'test-omp.lua') -- tools: torch.include('nnx', 'ConfusionMatrix.lua') torch.include('nnx', 'Logger.lua') +torch.include('nnx', 'Probe.lua') -- OpenMP module: torch.include('nnx', 'OmpModule.lua') diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec index 40aa155..ed70f18 100644 --- a/nnx-1.0-1.rockspec +++ b/nnx-1.0-1.rockspec @@ -58,6 +58,7 @@ build = { install_files(/lua/nnx Abs.lua) install_files(/lua/nnx ConfusionMatrix.lua) install_files(/lua/nnx Logger.lua) + install_files(/lua/nnx Probe.lua) install_files(/lua/nnx HardShrink.lua) install_files(/lua/nnx Narrow.lua) install_files(/lua/nnx Power.lua) |