Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas LĂ©onard <nick@nikopia.org>2017-05-24 23:01:04 +0300
committerGitHub <noreply@github.com>2017-05-24 23:01:04 +0300
commit5913e311d4741b82f12cd14447366403ccba7a98 (patch)
tree7baf4185af0ff59e677a0b096d290bd915f2d17f
parentf66ca966595af65dcf84aa7173b6bba597958fff (diff)
parentbbcf8295c75ef860576781b3cb45bec73d400f4b (diff)
Merge pull request #1224 from nicholas-leonard/PrintSize
PrintSize
-rw-r--r--PrintSize.lua36
-rwxr-xr-xdoc/simple.md13
-rwxr-xr-xinit.lua1
3 files changed, 50 insertions, 0 deletions
diff --git a/PrintSize.lua b/PrintSize.lua
new file mode 100644
index 0000000..d8dc91b
--- /dev/null
+++ b/PrintSize.lua
@@ -0,0 +1,36 @@
+local PrintSize, parent = torch.class('nn.PrintSize', 'nn.Module')
+
+function PrintSize:__init(prefix)
+ parent.__init(self)
+ self.prefix = prefix or "PrintSize"
+end
+
+function PrintSize:updateOutput(input)
+ self.output = input
+ local size
+ if torch.type(input) == 'table' then
+ size = input
+ elseif torch.type(input) == 'nil' then
+ size = 'missing size'
+ else
+ size = input:size()
+ end
+ print(self.prefix..":input\n", size)
+ return self.output
+end
+
+
+function PrintSize:updateGradInput(input, gradOutput)
+ local size
+ if torch.type(gradOutput) == 'table' then
+ size = gradOutput
+ elseif torch.type(gradOutput) == 'nil' then
+ size = 'missing size'
+ else
+ size = gradOutput:size()
+ end
+ print(self.prefix..":gradOutput\n", size)
+ self.gradInput = gradOutput
+ return self.gradInput
+end
+
diff --git a/doc/simple.md b/doc/simple.md
index 25d4063..de2f46d 100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -60,6 +60,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [Constant](#nn.Constant) : outputs a constant value given an input (which is ignored);
* [WhiteNoise](#nn.WhiteNoise) : adds isotropic Gaussian noise to the signal when in training mode;
* [OneHot](#nn.OneHot) : transforms a tensor of indices into [one-hot](https://en.wikipedia.org/wiki/One-hot) encoding;
+ * [PrintSize](#nn.PrintSize) : prints the size of `input` and `gradOutput` (useful for debugging);
<a name="nn.Linear"></a>
## Linear ##
@@ -1750,3 +1751,15 @@ oh:forward(torch.Tensor{{3,2,1},{1,2,3}})
0 0 1 0 0
[torch.DoubleTensor of size 2x3x5]
```
+
+<a name='nn.PrintSize'></a>
+## PrintSize ##
+
+```lua
+module = nn.PrintSize(name)
+```
+
+This module is useful for debugging complicated module composites.
+It prints the size of the `input` and `gradOutput` during `forward`
+and `backward` propagation respectively.
+The `name` is a string used to identify the module along side the printed size. \ No newline at end of file
diff --git a/init.lua b/init.lua
index 3b27c0a..ac7396f 100755
--- a/init.lua
+++ b/init.lua
@@ -62,6 +62,7 @@ require('nn.SpatialDropout')
require('nn.VolumetricDropout')
require('nn.WhiteNoise')
require('nn.OneHot')
+require('nn.PrintSize')
require('nn.CAddTable')
require('nn.CDivTable')