Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Probe.lua62
-rw-r--r--init.lua1
-rw-r--r--nnx-1.0-1.rockspec1
3 files changed, 64 insertions, 0 deletions
diff --git a/Probe.lua b/Probe.lua
new file mode 100644
index 0000000..3c93cd3
--- /dev/null
+++ b/Probe.lua
@@ -0,0 +1,62 @@
+local Probe, parent = torch.class('nn.Probe', 'nn.Module')
+
+function Probe:__init(...)
+ parent.__init(self)
+ xlua.unpack_class(self, {...}, 'nn.Probe',
+ 'print/display input/gradients of a network',
+ {arg='name', type='string', help='unique name to identify probe', req=true},
+ {arg='print', type='boolean', help='print full tensor', default=false},
+ {arg='display', type='boolean', help='display tensor', default=false},
+ {arg='size', type='boolean', help='print tensor size', default=false},
+ {arg='backw', type='boolean', help='activates probe for backward()', default=false})
+end
+
+function Probe:forward(input)
+ self.output = input
+ if self.size or self.content then
+ print('')
+ print('<probe::' .. self.name .. '> forward()')
+ if self.content then print(input)
+ elseif self.size then print(#input)
+ end
+ end
+ if self.display then
+ self.winf = image.display{image=input, win=self.winf}
+ end
+ return self.output
+end
+
+function Probe:backward(input, gradOutput)
+ self.gradInput = gradOutput
+ if self.backw then
+ if self.size or self.content then
+ print('')
+ print('<probe::' .. self.name .. '> backward()')
+ if self.content then print(gradOutput)
+ elseif self.size then print(#gradOutput)
+ end
+ end
+ if self.display then
+ self.winb = image.display{image=gradOutput, win=self.winb}
+ end
+ end
+ return self.gradInput
+end
+
+function Probe:write(file)
+ parent.write(self, file)
+ file:writeObject(self.name)
+ file:writeBool(self.content)
+ file:writeBool(self.display)
+ file:writeBool(self.size)
+ file:writeBool(self.backw)
+end
+
+function Probe:read(file)
+ parent.read(self, file)
+ self.name = file:readObject()
+ self.content = file:readBool()
+ self.display = file:readBool()
+ self.size = file:readBool()
+ self.backw = file:readBool()
+end
diff --git a/init.lua b/init.lua
index 27dfbfd..2d62899 100644
--- a/init.lua
+++ b/init.lua
@@ -47,6 +47,7 @@ torch.include('nnx', 'test-omp.lua')
-- tools:
torch.include('nnx', 'ConfusionMatrix.lua')
torch.include('nnx', 'Logger.lua')
+torch.include('nnx', 'Probe.lua')
-- OpenMP module:
torch.include('nnx', 'OmpModule.lua')
diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec
index 40aa155..ed70f18 100644
--- a/nnx-1.0-1.rockspec
+++ b/nnx-1.0-1.rockspec
@@ -58,6 +58,7 @@ build = {
install_files(/lua/nnx Abs.lua)
install_files(/lua/nnx ConfusionMatrix.lua)
install_files(/lua/nnx Logger.lua)
+ install_files(/lua/nnx Probe.lua)
install_files(/lua/nnx HardShrink.lua)
install_files(/lua/nnx Narrow.lua)
install_files(/lua/nnx Power.lua)