Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nleonard@twitter.com>2017-05-24 23:27:45 +0300
committerNicholas Leonard <nleonard@twitter.com>2017-05-24 23:27:45 +0300
commit1535bd3f9320136a40a600a434b8ef3fddae0890 (patch)
treebda00fb057a1ae7cfbaff799fe627f6492351b24
parent5913e311d4741b82f12cd14447366403ccba7a98 (diff)
ZeroGrad
-rw-r--r--ZeroGrad.lua14
-rwxr-xr-xdoc/simple.md23
-rwxr-xr-xinit.lua1
-rwxr-xr-xtest.lua10
4 files changed, 47 insertions, 1 deletions
diff --git a/ZeroGrad.lua b/ZeroGrad.lua
new file mode 100644
index 0000000..7c941ce
--- /dev/null
+++ b/ZeroGrad.lua
@@ -0,0 +1,14 @@
+local ZeroGrad, parent = torch.class('nn.ZeroGrad', 'nn.Module')
+
+function ZeroGrad:updateOutput(input)
+ self.output:set(input)
+ return self.output
+end
+
+-- the gradient is simply zeroed.
+-- useful when you don't want to backpropagate through certain paths.
+function ZeroGrad:updateGradInput(input, gradOutput)
+ self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
+ self.gradInput = nn.utils.recursiveFill(self.gradInput, 0)
+ return self.gradInput
+end
diff --git a/doc/simple.md b/doc/simple.md
index de2f46d..5e31080 100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -61,6 +61,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [WhiteNoise](#nn.WhiteNoise) : adds isotropic Gaussian noise to the signal when in training mode;
* [OneHot](#nn.OneHot) : transforms a tensor of indices into [one-hot](https://en.wikipedia.org/wiki/One-hot) encoding;
* [PrintSize](#nn.PrintSize) : prints the size of `input` and `gradOutput` (useful for debugging);
+ * [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`.
<a name="nn.Linear"></a>
## Linear ##
@@ -1762,4 +1763,24 @@ module = nn.PrintSize(name)
This module is useful for debugging complicated module composites.
It prints the size of the `input` and `gradOutput` during `forward`
and `backward` propagation respectively.
-The `name` is a string used to identify the module along side the printed size. \ No newline at end of file
+The `name` is a string used to identify the module alongside the printed size.
+
+<a name='nn.ZeroGrad'></a>
+## ZeroGrad ##
+
+```lua
+module = nn.ZeroGrad()
+input = torch.Tensor{1,2}
+gradOutput = torch.Tensor{3,4}
+print(module:forward(input))
+ 1
+ 2
+[torch.DoubleTensor of size 2]
+
+print(module:backward(input, gradOutput))
+ 0
+ 0
+[torch.DoubleTensor of size 2]
+```
+
+The module zeros the `gradInput` but forwards the `input` as-is. \ No newline at end of file
diff --git a/init.lua b/init.lua
index ac7396f..b397d77 100755
--- a/init.lua
+++ b/init.lua
@@ -63,6 +63,7 @@ require('nn.VolumetricDropout')
require('nn.WhiteNoise')
require('nn.OneHot')
require('nn.PrintSize')
+require('nn.ZeroGrad')
require('nn.CAddTable')
require('nn.CDivTable')
diff --git a/test.lua b/test.lua
index 9edc716..2dafb09 100755
--- a/test.lua
+++ b/test.lua
@@ -8538,6 +8538,16 @@ function nntest.OneHot()
end
end
+function nntest.ZeroGrad()
+ local input = torch.randn(3,4)
+ local zg = nn.ZeroGrad()
+ local output = zg:forward(input)
+ mytester:assertTensorEq(input, output, 0.00000001)
+ local gradInput = zg:backward(input, input)
+ local gradInput2 = gradInput:clone():zero()
+ mytester:assertTensorEq(gradInput, gradInput2, 0.0000001)
+end
+
mytester:add(nntest)