diff options
author | Nicholas LĂ©onard <nick@nikopia.org> | 2017-05-24 23:01:04 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-24 23:01:04 +0300 |
commit | 5913e311d4741b82f12cd14447366403ccba7a98 (patch) | |
tree | 7baf4185af0ff59e677a0b096d290bd915f2d17f | |
parent | f66ca966595af65dcf84aa7173b6bba597958fff (diff) | |
parent | bbcf8295c75ef860576781b3cb45bec73d400f4b (diff) |
Merge pull request #1224 from nicholas-leonard/PrintSize
PrintSize
-rw-r--r-- | PrintSize.lua | 36 | ||||
-rwxr-xr-x | doc/simple.md | 13 | ||||
-rwxr-xr-x | init.lua | 1 |
3 files changed, 50 insertions, 0 deletions
diff --git a/PrintSize.lua b/PrintSize.lua new file mode 100644 index 0000000..d8dc91b --- /dev/null +++ b/PrintSize.lua @@ -0,0 +1,36 @@ +local PrintSize, parent = torch.class('nn.PrintSize', 'nn.Module') + +function PrintSize:__init(prefix) + parent.__init(self) + self.prefix = prefix or "PrintSize" +end + +function PrintSize:updateOutput(input) + self.output = input + local size + if torch.type(input) == 'table' then + size = input + elseif torch.type(input) == 'nil' then + size = 'missing size' + else + size = input:size() + end + print(self.prefix..":input\n", size) + return self.output +end + + +function PrintSize:updateGradInput(input, gradOutput) + local size + if torch.type(gradOutput) == 'table' then + size = gradOutput + elseif torch.type(gradOutput) == 'nil' then + size = 'missing size' + else + size = gradOutput:size() + end + print(self.prefix..":gradOutput\n", size) + self.gradInput = gradOutput + return self.gradInput +end + diff --git a/doc/simple.md b/doc/simple.md index 25d4063..de2f46d 100755 --- a/doc/simple.md +++ b/doc/simple.md @@ -60,6 +60,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi * [Constant](#nn.Constant) : outputs a constant value given an input (which is ignored); * [WhiteNoise](#nn.WhiteNoise) : adds isotropic Gaussian noise to the signal when in training mode; * [OneHot](#nn.OneHot) : transforms a tensor of indices into [one-hot](https://en.wikipedia.org/wiki/One-hot) encoding; + * [PrintSize](#nn.PrintSize) : prints the size of `input` and `gradOutput` (useful for debugging); <a name="nn.Linear"></a> ## Linear ## @@ -1750,3 +1751,15 @@ oh:forward(torch.Tensor{{3,2,1},{1,2,3}}) 0 0 1 0 0 [torch.DoubleTensor of size 2x3x5] ``` + +<a name='nn.PrintSize'></a> +## PrintSize ## + +```lua +module = nn.PrintSize(name) +``` + +This module is useful for debugging complicated module composites. +It prints the size of the `input` and `gradOutput` during `forward` +and `backward` propagation respectively. +The `name` is a string used to identify the module along side the printed size.
\ No newline at end of file @@ -62,6 +62,7 @@ require('nn.SpatialDropout') require('nn.VolumetricDropout') require('nn.WhiteNoise') require('nn.OneHot') +require('nn.PrintSize') require('nn.CAddTable') require('nn.CDivTable') |