
github.com/torch/nn.git
author    Nicholas Léonard <nick@nikopia.org>  2014-07-10 00:37:50 +0400
committer Nicholas Léonard <nick@nikopia.org>  2014-07-10 00:37:50 +0400
commit    b16de53afa14cb19b2fd229105500298eed455e6 (patch)
tree      bb587b74463dd187749479b7f6b88f85c3cfb0e0
parent    e197e2f0acf3cbf2cb09765d8fcefa968fea795f (diff)
Updated simple.md Dropout doc to v2 of Dropout.
-rw-r--r--  doc/simple.md | 43
1 file changed, 25 insertions(+), 18 deletions(-)
diff --git a/doc/simple.md b/doc/simple.md
index a708da5..9a5543d 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -92,21 +92,24 @@ commensurate output element be zero. This has proven an effective technique for
regularization and preventing the co-adaptation of neurons
(see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)).
+Furthermore, the outputs are scaled by a factor of `1/(1-p)` during training. This allows the
+`input` to be simply forwarded as-is during evaluation.
+
In this example, we demonstrate how the call to [forward](module.md#output-forwardinput) samples
-different `outputs` given the same `input`:
+different `outputs` to drop out (the zeros) given the same `input`:
```lua
module = nn.Dropout()
> x=torch.Tensor{{1,2,3,4},{5,6,7,8}}
> =module:forward(x)
- 0 2 0 4
- 5 0 0 0
+ 2 0 0 8
+ 10 0 14 0
[torch.DoubleTensor of dimension 2x4]
> =module:forward(x)
- 0 2 3 4
- 5 6 7 8
+ 0 0 6 0
+ 10 0 0 0
[torch.DoubleTensor of dimension 2x4]
```
@@ -114,35 +117,39 @@ module = nn.Dropout()
[Backward](module.md#gradinput-backwardinput-gradoutput) drops out the gradients at the same location:
```lua
> =module:forward(x)
- 1 2 0 0
- 5 6 7 0
+ 0 4 0 0
+ 10 12 0 16
[torch.DoubleTensor of dimension 2x4]
-> return module:backward(x,x:clone():fill(1))
- 1 1 0 0
- 1 1 1 0
+> =module:backward(x,x:clone():fill(1))
+ 0 2 0 0
+ 2 2 0 2
[torch.DoubleTensor of dimension 2x4]
+
```
+In both cases, the kept elements of the `input` (forward) and `gradOutput` (backward) are scaled by `1/(1-p)`, which in this case is `2`.
-During [evaluation](module.md#evaluate), `Dropout` does nothing more than scales the input by `1-p` such that
-all elements of the input are considered.
+During [evaluation](module.md#evaluate), `Dropout` does nothing more than
+forward the input such that all elements of the input are considered.
```lua
> module:evaluate()
-> return module:forward(x)
- 0.5000 1.0000 1.5000 2.0000
- 2.5000 3.0000 3.5000 4.0000
+> module:forward(x)
+ 1 2 3 4
+ 5 6 7 8
[torch.DoubleTensor of dimension 2x4]
+
```
We can return to training our model by first calling [Module:training()](module.md#training):
```lua
> module:training()
-> module:forward(x)
- 0 0 0 4
- 5 6 0 0
+> return module:forward(x)
+ 2 4 6 0
+ 0 0 0 16
[torch.DoubleTensor of dimension 2x4]
+
```
When used, `Dropout` should normally be applied to the input of parameterized
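
The inverted-dropout behavior this patch documents (zero elements with probability `p`, scale survivors by `1/(1-p)` during training, pass the input through unchanged at evaluation, and drop gradients at the same locations in backward) can be sketched in Python. This is an illustrative standalone sketch, not torch/nn code; the function names are hypothetical:

```python
import random


def dropout_forward(x, p=0.5, train=True, rng=None):
    """Inverted dropout on a flat list of numbers.

    Training: each element is zeroed with probability p and survivors
    are scaled by 1/(1-p), so evaluation needs no rescaling.
    Returns (output, mask); the mask is reused by the backward pass.
    """
    if not train:
        # Evaluation: the input is simply forwarded as-is.
        return list(x), [1.0] * len(x)
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)
    mask = [0.0 if rng.random() < p else scale for _ in x]
    return [xi * mi for xi, mi in zip(x, mask)], mask


def dropout_backward(grad_output, mask):
    """Gradients are dropped (and scaled) at the same locations as forward."""
    return [g * m for g, m in zip(grad_output, mask)]
```

With `p = 0.5` the surviving elements come out doubled, matching the `2 0 0 8` style of output shown in the patched examples, while evaluation returns the input unchanged.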