Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-07-10 18:28:55 +0400
committerSoumith Chintala <soumith@gmail.com>2014-07-10 18:28:55 +0400
commit1625eef2d4448c75711c7a87811b8a9e81a9ac40 (patch)
treed4c9f0b38c1cba20fcf1178ec41c3736b3b40fbf
parent75a2279ef3dac76046f128a2d77e1ffd2dcd5397 (diff)
parentca59595974936b9ff42377677e7553627f397197 (diff)
Merge github.com:torch/nn into concattable_tableinput
-rw-r--r--Dropout.lua46
-rw-r--r--ElementTable.lua34
-rw-r--r--ReLU.lua5
-rw-r--r--Sequential.lua10
-rw-r--r--doc/image/relu.pngbin0 -> 19636 bytes
-rw-r--r--doc/module.md8
-rw-r--r--doc/simple.md83
-rw-r--r--doc/table.md95
-rwxr-xr-xdoc/transfer.md17
-rw-r--r--init.lua3
-rw-r--r--test/test.lua73
11 files changed, 374 insertions, 0 deletions
diff --git a/Dropout.lua b/Dropout.lua
new file mode 100644
index 0000000..a92faf2
--- /dev/null
+++ b/Dropout.lua
@@ -0,0 +1,46 @@
+local Dropout, Parent = torch.class('nn.Dropout', 'nn.Module')
+
+function Dropout:__init(p,v1)
+ Parent.__init(self)
+ self.p = p or 0.5
+ self.train = true
+ -- version 2 scales output during training instead of evaluation
+ self.v2 = not v1
+ if self.p >= 1 or self.p < 0 then
+ error('<Dropout> illegal percentage, must be 0 <= p < 1')
+ end
+ self.noise = torch.Tensor()
+ self.fnoise = torch.Tensor()
+end
+
+function Dropout:updateOutput(input)
+ self.output:resizeAs(input):copy(input)
+ if self.train then
+ self.fnoise = self.fnoise:float()
+ self.fnoise:resize(input:size())
+ self.noise:resizeAs(input)
+ self.fnoise:bernoulli(1-self.p)
+ self.noise:copy(self.fnoise)
+ if self.v2 then
+ self.noise:div(1-self.p)
+ end
+ self.output:cmul(self.noise)
+ elseif not self.v2 then
+ self.output:mul(1-self.p)
+ end
+ return self.output
+end
+
+function Dropout:updateGradInput(input, gradOutput)
+ if self.train then
+ self.gradInput:resizeAs(gradOutput):copy(gradOutput)
+ self.gradInput:cmul(self.noise) -- simply mask the gradients with the noise vector
+ else
+ error('backprop only defined while training')
+ end
+ return self.gradInput
+end
+
+function Dropout:setp(p)
+ self.p = p
+end
diff --git a/ElementTable.lua b/ElementTable.lua
new file mode 100644
index 0000000..cb3ff0f
--- /dev/null
+++ b/ElementTable.lua
@@ -0,0 +1,34 @@
+local ElementTable, parent = torch.class('nn.ElementTable', 'nn.Module')
+
+function ElementTable:__init(index)
+ parent.__init(self)
+ self.index = index
+ self.gradInput = {}
+end
+
+function ElementTable:updateOutput(input)
+ self.output = input[self.index]
+ return self.output
+end
+
+function ElementTable:updateGradInput(input, gradOutput)
+ if #self.gradInput == 0 then
+ local function zeroTableCopy(t1, t2)
+ for k, v in pairs(t2) do
+ if (torch.type(v) == "table") then
+ t1[k] = zeroTableCopy(t1[k] or {}, t2[k])
+ else
+ t1[k] = v:clone():zero()
+ end
+ end
+ return t1
+ end
+ zeroTableCopy(self.gradInput, input)
+ end
+ self.gradInput[self.index] = gradOutput
+ return self.gradInput
+end
+
+function ElementTable:type(type)
+ self.gradInput = {}
+end
diff --git a/ReLU.lua b/ReLU.lua
new file mode 100644
index 0000000..65bf196
--- /dev/null
+++ b/ReLU.lua
@@ -0,0 +1,5 @@
+local ReLU, Parent = torch.class('nn.ReLU', 'nn.Threshold')
+
+function ReLU:__init(p)
+ Parent.__init(self,0,0)
+end
diff --git a/Sequential.lua b/Sequential.lua
index b43bd99..ec3247b 100644
--- a/Sequential.lua
+++ b/Sequential.lua
@@ -13,6 +13,16 @@ function Sequential:add(module)
return self
end
+function Sequential:insert(module, index)
+ index = index or (#self.modules + 1)
+ if index > (#self.modules + 1) then
+ error"index should be contiguous to existing modules"
+ end
+ table.insert(self.modules, index, module)
+ self.output = self.modules[#self.modules].output
+ self.gradInput = self.modules[1].gradInput
+end
+
function Sequential:size()
return #self.modules
end
diff --git a/doc/image/relu.png b/doc/image/relu.png
new file mode 100644
index 0000000..d60d2ab
--- /dev/null
+++ b/doc/image/relu.png
Binary files differ
diff --git a/doc/module.md b/doc/module.md
index c8bf501..273fec2 100644
--- a/doc/module.md
+++ b/doc/module.md
@@ -274,3 +274,11 @@ Custom modules should not override this function. They should instead override [
This function will go over all the weights and gradWeights and make them view into a single tensor (one for weights and one for gradWeights). Since the storage of every weight and gradWeight is changed, this function should be called only once on a given network.
+<a name="nn.Module.training"/>
+### training() ###
+This sets the mode of the Module (or sub-modules) to `train=true`. This is useful for modules like [Dropout](simple.md#nn.Dropout) that have a different behaviour during training vs evaluation.
+
+<a name="nn.Module.evaluate"/>
+### evaluate() ###
+This sets the mode of the Module (or sub-modules) to `train=false`. This is useful for modules like [Dropout](simple.md#nn.Dropout) that have a different behaviour during training vs evaluation.
+
diff --git a/doc/simple.md b/doc/simple.md
index ad883b7..9a5543d 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -79,6 +79,89 @@ values in a a vector where all other elements are zeros. The
indices should not exceed the stated dimensions of the input to the
layer (10000 in the example).
+
+<a name="nn.Dropout"/>
+## Dropout ##
+
+`module` = `nn.Dropout(p)`
+
+During training, `Dropout` masks parts of the `input` using binary samples from
+a [bernoulli](http://en.wikipedia.org/wiki/Bernoulli_distribution) distribution.
+Each `input` element has a probability of `p` of being dropped, i.e having its
+commensurate output element be zero. This has proven an effective technique for
+regularization and preventing the co-adaptation of neurons
+(see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)).
+
+Furthermore, the ouputs are scaled by a factor of `1/(1-p)` during training. This allows the
+`input` to be simply forwarded as-is during evaluation.
+
+In this example, we demonstrate how the call to [forward](module.md#output-forwardinput) samples
+different `outputs` to dropout (the zeros) given the same `input`:
+```lua
+module = nn.Dropout()
+
+> x=torch.Tensor{{1,2,3,4},{5,6,7,8}}
+
+> =module:forward(x)
+ 2 0 0 8
+ 10 0 14 0
+[torch.DoubleTensor of dimension 2x4]
+
+> =module:forward(x)
+ 0 0 6 0
+ 10 0 0 0
+[torch.DoubleTensor of dimension 2x4]
+
+```
+
+[Backward](module.md#gradinput-backwardinput-gradoutput) drops out the gradients at the same location:
+```lua
+> =module:forward(x)
+ 0 4 0 0
+ 10 12 0 16
+[torch.DoubleTensor of dimension 2x4]
+
+> =module:backward(x,x:clone():fill(1))
+ 0 2 0 0
+ 2 2 0 2
+[torch.DoubleTensor of dimension 2x4]
+
+```
+In both cases the `gradOutput` and `input` are scaled by `1/(1-p)`, which in this case is `2`.
+
+During [evaluation](module.md#evaluate), `Dropout` does nothing more than
+forward the input such that all elements of the input are considered.
+```lua
+> module:evaluate()
+
+> module:forward(x)
+ 1 2 3 4
+ 5 6 7 8
+[torch.DoubleTensor of dimension 2x4]
+
+```
+
+We can return to training our model by first calling [Module:training()](module.md#training):
+```lua
+> module:training()
+
+> return module:forward(x)
+ 2 4 6 0
+ 0 0 0 16
+[torch.DoubleTensor of dimension 2x4]
+
+```
+
+When used, `Dropout` should normally be applied to the input of parameterized
+[Modules](module.md#nn.Module) like [Linear](#nn.Linear)
+or [SpatialConvolution](convolution.md#nn.SpatialConvolution).
+A `p` of `0.5` (the default) is usually okay for hidden layers.
+`Dropout` can sometimes be used successfully on the dataset inputs with a `p` around `0.2`.
+It sometimes works best following [Transfer](transfer.md) Modules
+like [ReLU](transfer.md#nn.ReLU). All this depends a great deal on the dataset so its up
+to the user to try different combinations.
+
+
<a name="nn.Abs"/>
## Abs ##
diff --git a/doc/table.md b/doc/table.md
index 4117117..60b6dea 100644
--- a/doc/table.md
+++ b/doc/table.md
@@ -9,6 +9,7 @@ This allows one to build very rich architectures:
* Table Conversion Modules convert between tables and Tensors:
* [SplitTable](#nn.SplitTable) : splits a Tensor into a table of Tensors;
* [JoinTable](#nn.JoinTable) : joins a table of Tensors into a Tensor;
+ * [ElementTable](#nn.ElementTable) : retrieve one element from a table;
* Pair Modules compute a measure like distance or similarity from a pair (table) of input Tensors :
* [PairwiseDistance](#nn.PairwiseDistance) : outputs the `p`-norm. distance between inputs;
* [DotProduct](#nn.DotProduct) : outputs the dot product (similarity) between inputs;
@@ -375,6 +376,100 @@ for i=1,100 do -- A few steps of training such a network..
end
```
+<a name="nn.ElementTable"/>
+## ElementTable ##
+
+`module` = `ElementTable(index)`
+
+Creates a module that takes a Table as input and outputs the element at index `index`.
+This can be either a Table or a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).
+
+The gradients of the non-`index` elements are zeroed Tensors of the same size. This is true regardless of the
+dept of the encapsulated Tensor as the function used internally to do so is recursive.
+
+Example 1:
+```lua
+> input = {torch.randn(2,3), torch.randn(2,1)}
+ [0.0002s]
+> =nn.ElementTable(1):forward(input)
+-0.3060 0.1398 0.2707
+ 0.0576 1.5455 0.0610
+[torch.DoubleTensor of dimension 2x3]
+
+ [0.0002s]
+> =nn.ElementTable(2):forward(input)
+ 2.3080
+-0.2955
+[torch.DoubleTensor of dimension 2x1]
+
+> =unpack(nn.ElementTable(1):backward(input, torch.randn(2,3)))
+-0.4891 -0.3495 -0.3182
+-2.0999 0.7381 -0.5312
+[torch.DoubleTensor of dimension 2x3]
+
+0
+0
+[torch.DoubleTensor of dimension 2x1]
+
+```
+
+Example 2:
+```lua
+> input = {torch.randn(2,3), {torch.randn(2,1), {torch.randn(2,2)}}}
+
+> =nn.ElementTable(2):forward(input)
+{
+ 1 : DoubleTensor - size: 2x1
+ 2 :
+ {
+ 1 : DoubleTensor - size: 2x2
+ }
+}
+
+> =unpack(nn.ElementTable(2):backward(input, {torch.randn(2,1), {torch.randn(2,2)}}))
+0 0 0
+0 0 0
+[torch.DoubleTensor of dimension 2x3]
+
+{
+ 1 : DoubleTensor - size: 2x1
+ 2 :
+ {
+ 1 : DoubleTensor - size: 2x2
+ }
+}
+
+> gradInput = nn.ElementTable(1):backward(input, torch.randn(2,3))
+
+> =gradInput
+{
+ 1 : DoubleTensor - size: 2x3
+ 2 :
+ {
+ 1 : DoubleTensor - size: 2x1
+ 2 :
+ {
+ 1 : DoubleTensor - size: 2x2
+ }
+ }
+}
+
+> =gradInput[1]
+-0.3400 -0.0404 1.1885
+ 1.2865 0.4107 0.6506
+[torch.DoubleTensor of dimension 2x3]
+
+> gradInput[2][1]
+0
+0
+[torch.DoubleTensor of dimension 2x1]
+
+> gradInput[2][2][1]
+0 0
+0 0
+[torch.DoubleTensor of dimension 2x2]
+
+```
<a name="nn.PairwiseDistance"/>
## PairwiseDistance ##
diff --git a/doc/transfer.md b/doc/transfer.md
index 0a5334a..0a47d7c 100755
--- a/doc/transfer.md
+++ b/doc/transfer.md
@@ -231,6 +231,23 @@ gnuplot.grid(true)
```
![](image/tanh.png)
+<a name="nn.ReLU"/>
+## ReLU ##
+
+Applies the rectified linear unit (`ReLU`) function element-wise to the input Tensor,
+thus outputting a Tensor of the same dimension.
+
+```lua
+ii=torch.linspace(-3,3)
+m=nn.ReLU()
+oo=m:forward(ii)
+go=torch.ones(100)
+gi=m:backward(ii,go)
+gnuplot.plot({'f(x)',ii,oo,'+-'},{'df/dx',ii,gi,'+-'})
+gnuplot.grid(true)
+```
+![](image/relu.png)
+
<a name="nn.AddConstant"/>
## AddConstant ##
diff --git a/init.lua b/init.lua
index 757c9ec..c8aff0e 100644
--- a/init.lua
+++ b/init.lua
@@ -26,6 +26,7 @@ include('Mul.lua')
include('MulConstant.lua')
include('Add.lua')
include('AddConstant.lua')
+include('Dropout.lua')
include('CAddTable.lua')
include('CDivTable.lua')
@@ -57,6 +58,7 @@ include('Sqrt.lua')
include('HardShrink.lua')
include('SoftShrink.lua')
include('Threshold.lua')
+include('ReLU.lua')
include('LookupTable.lua')
include('SpatialConvolution.lua')
@@ -85,6 +87,7 @@ include('ParallelTable.lua')
include('ConcatTable.lua')
include('SplitTable.lua')
include('JoinTable.lua')
+include('ElementTable.lua')
include('CriterionTable.lua')
include('Identity.lua')
diff --git a/test/test.lua b/test/test.lua
index 9ecc923..73426fb 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -60,6 +60,36 @@ function nntest.CMul()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
+function nntest.Dropout()
+ local p = 0.2 --prob of droping out a neuron
+ local input = torch.Tensor(1000):fill((1-p))
+ local module = nn.Dropout(p)
+ -- version 2
+ local output = module:forward(input)
+ mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output')
+ local gradInput = module:backward(input, input)
+ mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput')
+ -- version 1 (old nnx version)
+ local input = input:fill(1)
+ local module = nn.Dropout(p,true)
+ local output = module:forward(input)
+ mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output')
+ local gradInput = module:backward(input, input)
+ mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput')
+end
+
+function nntest.ReLU()
+ local input = torch.randn(3,4)
+ local gradOutput = torch.randn(3,4)
+ local module = nn.ReLU()
+ local output = module:forward(input)
+ local output2 = input:clone():gt(input, 0):cmul(input)
+ mytester:assertTensorEq(output, output2, 0.000001, 'ReLU output')
+ local gradInput = module:backward(input, gradOutput)
+ local gradInput2 = input:clone():gt(input, 0):cmul(gradOutput)
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'ReLU gradInput')
+end
+
function nntest.Exp()
local ini = math.random(10,20)
local inj = math.random(10,20)
@@ -1869,6 +1899,49 @@ function nntest.SplitTable()
end
end
+function nntest.ElementTable()
+ local input = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local gradOutputs = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local zeros = {
+ torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(),
+ {torch.Tensor(3,4,5):zero()},
+ {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
+ }
+ local function equal(t1, t2, msg)
+ if (torch.type(t1) == "table") then
+ for k, v in pairs(t2) do
+ equal(t1[k], t2[k])
+ end
+ else
+ mytester:assertTensorEq(t1, t2, 0.00001, msg)
+ end
+ end
+ local nonIdx = {2,3,4,1}
+ local module
+ for idx = 1,#input do
+ module = nn.ElementTable(idx)
+ local output = module:forward(input)
+ equal(output, input[idx], "output dimension " .. idx)
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+ end
+ module:float()
+ local idx = #input
+ local output = module:forward(input)
+ equal(output, input[idx], "type output")
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+end
function nntest.View()
local input = torch.rand(10)