author      Nicholas Léonard <nick@nikopia.org>          2017-05-24 23:46:23 +0300
committer   GitHub <noreply@github.com>                  2017-05-24 23:46:23 +0300
commit      e40e2816e23cebc85fd5733e716e903a2d02c175 (patch)
tree        bda00fb057a1ae7cfbaff799fe627f6492351b24
parent      5913e311d4741b82f12cd14447366403ccba7a98 (diff)
parent      1535bd3f9320136a40a600a434b8ef3fddae0890 (diff)
Merge pull request #1225 from nicholas-leonard/ZeroGrad
nn.ZeroGrad
-rw-r--r--   ZeroGrad.lua   | 14
-rwxr-xr-x   doc/simple.md  | 23
-rwxr-xr-x   init.lua       |  1
-rwxr-xr-x   test.lua       | 10
4 files changed, 47 insertions, 1 deletion
diff --git a/ZeroGrad.lua b/ZeroGrad.lua
new file mode 100644
index 0000000..7c941ce
--- /dev/null
+++ b/ZeroGrad.lua
@@ -0,0 +1,14 @@
+local ZeroGrad, parent = torch.class('nn.ZeroGrad', 'nn.Module')
+
+function ZeroGrad:updateOutput(input)
+   self.output:set(input)
+   return self.output
+end
+
+-- the gradient is simply zeroed.
+-- useful when you don't want to backpropgate through certain paths.
+function ZeroGrad:updateGradInput(input, gradOutput)
+   self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
+   self.gradInput = nn.utils.recursiveFill(self.gradInput, 0)
+   return self.gradInput
+end
diff --git a/doc/simple.md b/doc/simple.md
index de2f46d..5e31080 100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -61,6 +61,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
  * [WhiteNoise](#nn.WhiteNoise) : adds isotropic Gaussian noise to the signal when in training mode;
  * [OneHot](#nn.OneHot) : transforms a tensor of indices into [one-hot](https://en.wikipedia.org/wiki/One-hot) encoding;
  * [PrintSize](#nn.PrintSize) : prints the size of `input` and `gradOutput` (useful for debugging);
+ * [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`.

 <a name="nn.Linear"></a>
 ## Linear ##
@@ -1762,4 +1763,24 @@ module = nn.PrintSize(name)
 This module is useful for debugging complicated module composites.
 It prints the size of the `input` and `gradOutput` during `forward` and `backward` propagation respectively.
-The `name` is a string used to identify the module along side the printed size.
\ No newline at end of file
+The `name` is a string used to identify the module along side the printed size.
+
+<a name='nn.ZeroGrad'></a>
+## ZeroGrad ##
+
+```lua
+module = nn.ZeroGrad()
+input = torch.Tensor{1,2}
+gradOutput = torch.Tensor{3,4}
+print(module:forward(input))
+ 1
+ 2
+[torch.DoubleTensor of size 2]
+
+print(module:backward(input, gradOutput))
+ 0
+ 0
+[torch.DoubleTensor of size 2]
+```
+
+The module zeros the `gradInput` but forwards the `input` as-is.
\ No newline at end of file
diff --git a/init.lua b/init.lua
@@ -63,6 +63,7 @@ require('nn.VolumetricDropout')
 require('nn.WhiteNoise')
 require('nn.OneHot')
 require('nn.PrintSize')
+require('nn.ZeroGrad')

 require('nn.CAddTable')
 require('nn.CDivTable')
diff --git a/test.lua b/test.lua
@@ -8538,6 +8538,16 @@ function nntest.OneHot()
    end
 end

+function nntest.ZeroGrad()
+   local input = torch.randn(3,4)
+   local zg = nn.ZeroGrad()
+   local output = zg:forward(input)
+   mytester:assertTensorEq(input, output, 0.00000001)
+   local gradInput = zg:backward(input, input)
+   local gradInput2 = gradInput:clone():zero()
+   mytester:assertTensorEq(gradInput, gradInput2, 0.0000001)
+end
+
 mytester:add(nntest)
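For readers new to the module, a short usage sketch follows. It is not part of the commit above: the network layout and layer sizes are made up for illustration, and only `nn.ZeroGrad` comes from this merge; `nn.Sequential`, `nn.Linear`, and `torch.randn` are standard Torch/nn components.

```lua
require 'nn'

-- nn.ZeroGrad passes the input through unchanged on forward, but returns a
-- zeroed gradInput on backward, so every module feeding into it receives
-- zero gradients (hypothetical network; sizes are arbitrary).
local blocked = nn.Sequential()
   :add(nn.Linear(10, 5))
   :add(nn.ZeroGrad())

local net = nn.Sequential()
   :add(blocked)
   :add(nn.Linear(5, 2))

local input = torch.randn(4, 10)
local output = net:forward(input)

-- Backpropagate an arbitrary gradOutput: ZeroGrad hands the first Linear a
-- zero gradOutput, so the gradient with respect to the input is zero too.
local gradInput = net:backward(input, torch.randn(4, 2))
print(gradInput:sum())  -- 0, since ZeroGrad cut the gradient path
```

Because the preceding `Linear` receives an all-zero `gradOutput`, its parameter gradients stay zero as well, which is the "don't backpropagate through certain paths" use case mentioned in the code comment; the unit test in the diff checks the same contract (identity forward, all-zero `gradInput`).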