diff options
-rw-r--r-- | Dropout.lua | 46 | ||||
-rw-r--r-- | ElementTable.lua | 37 | ||||
-rw-r--r-- | ReLU.lua | 5 | ||||
-rw-r--r-- | Sequential.lua | 13 | ||||
-rw-r--r-- | doc/image/relu.png | bin | 0 -> 19636 bytes | |||
-rw-r--r-- | doc/module.md | 8 | ||||
-rw-r--r-- | doc/simple.md | 83 | ||||
-rw-r--r-- | doc/table.md | 95 | ||||
-rwxr-xr-x | doc/transfer.md | 17 | ||||
-rw-r--r-- | init.lua | 3 | ||||
-rw-r--r-- | test/test.lua | 73 |
11 files changed, 380 insertions, 0 deletions
diff --git a/Dropout.lua b/Dropout.lua new file mode 100644 index 0000000..a92faf2 --- /dev/null +++ b/Dropout.lua @@ -0,0 +1,46 @@ +local Dropout, Parent = torch.class('nn.Dropout', 'nn.Module') + +function Dropout:__init(p,v1) + Parent.__init(self) + self.p = p or 0.5 + self.train = true + -- version 2 scales output during training instead of evaluation + self.v2 = not v1 + if self.p >= 1 or self.p < 0 then + error('<Dropout> illegal percentage, must be 0 <= p < 1') + end + self.noise = torch.Tensor() + self.fnoise = torch.Tensor() +end + +function Dropout:updateOutput(input) + self.output:resizeAs(input):copy(input) + if self.train then + self.fnoise = self.fnoise:float() + self.fnoise:resize(input:size()) + self.noise:resizeAs(input) + self.fnoise:bernoulli(1-self.p) + self.noise:copy(self.fnoise) + if self.v2 then + self.noise:div(1-self.p) + end + self.output:cmul(self.noise) + elseif not self.v2 then + self.output:mul(1-self.p) + end + return self.output +end + +function Dropout:updateGradInput(input, gradOutput) + if self.train then + self.gradInput:resizeAs(gradOutput):copy(gradOutput) + self.gradInput:cmul(self.noise) -- simply mask the gradients with the noise vector + else + error('backprop only defined while training') + end + return self.gradInput +end + +function Dropout:setp(p) + self.p = p +end diff --git a/ElementTable.lua b/ElementTable.lua new file mode 100644 index 0000000..cb3ff0f --- /dev/null +++ b/ElementTable.lua @@ -0,0 +1,37 @@ +local ElementTable, parent = torch.class('nn.ElementTable', 'nn.Module') + +function ElementTable:__init(index) + parent.__init(self) + self.index = index + self.gradInput = {} +end + +function ElementTable:updateOutput(input) + self.output = input[self.index] + return self.output +end + +function ElementTable:updateGradInput(input, gradOutput) + if #self.gradInput == 0 then + local function zeroTableCopy(t1, t2) + for k, v in pairs(t2) do + if (torch.type(v) == "table") then + t1[k] = zeroTableCopy(t1[k] or {}, 
t2[k]) + else + t1[k] = v:clone():zero() + end + end + return t1 + end + zeroTableCopy(self.gradInput, input) + end + self.gradInput[self.index] = gradOutput + return self.gradInput +end + +-- Reset the cached gradInput table so that updateGradInput rebuilds it from +-- the next input, and return self so that conversion calls can be chained. +function ElementTable:type(type) + self.gradInput = {} + return self +end diff --git a/ReLU.lua b/ReLU.lua new file mode 100644 index 0000000..65bf196 --- /dev/null +++ b/ReLU.lua @@ -0,0 +1,5 @@ +local ReLU, Parent = torch.class('nn.ReLU', 'nn.Threshold') + +function ReLU:__init(p) + Parent.__init(self,0,0) +end diff --git a/Sequential.lua b/Sequential.lua index b43bd99..ec3247b 100644 --- a/Sequential.lua +++ b/Sequential.lua @@ -13,6 +13,19 @@ function Sequential:add(module) return self end +-- Insert module at position index (default: append after the last module). +-- index must be between 1 and one past the last module; returns self like add(). +function Sequential:insert(module, index) + index = index or (#self.modules + 1) + if index < 1 or index > (#self.modules + 1) then + error"index should be contiguous to existing modules" + end + table.insert(self.modules, index, module) + self.output = self.modules[#self.modules].output + self.gradInput = self.modules[1].gradInput + return self +end + function Sequential:size() return #self.modules end diff --git a/doc/image/relu.png b/doc/image/relu.png Binary files differ new file mode 100644 index 0000000..d60d2ab --- /dev/null +++ b/doc/image/relu.png diff --git a/doc/module.md b/doc/module.md index c8bf501..273fec2 100644 --- a/doc/module.md +++ b/doc/module.md @@ -274,3 +274,11 @@ Custom modules should not override this function. They should instead override [ This function will go over all the weights and gradWeights and make them view into a single tensor (one for weights and one for gradWeights). Since the storage of every weight and gradWeight is changed, this function should be called only once on a given network. +<a name="nn.Module.training"/> +### training() ### +This sets the mode of the Module (or sub-modules) to `train=true`. This is useful for modules like [Dropout](simple.md#nn.Dropout) that have a different behaviour during training vs evaluation. 
+ +<a name="nn.Module.evaluate"/> +### evaluate() ### +This sets the mode of the Module (or sub-modules) to `train=false`. This is useful for modules like [Dropout](simple.md#nn.Dropout) that have a different behaviour during training vs evaluation. + diff --git a/doc/simple.md b/doc/simple.md index ad883b7..9a5543d 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -79,6 +79,89 @@ values in a a vector where all other elements are zeros. The indices should not exceed the stated dimensions of the input to the layer (10000 in the example). + +<a name="nn.Dropout"/> +## Dropout ## + +`module` = `nn.Dropout(p)` + +During training, `Dropout` masks parts of the `input` using binary samples from +a [bernoulli](http://en.wikipedia.org/wiki/Bernoulli_distribution) distribution. +Each `input` element has a probability of `p` of being dropped, i.e. having its +commensurate output element be zero. This has proven an effective technique for +regularization and preventing the co-adaptation of neurons +(see [Hinton et al. 2012](http://arxiv.org/abs/1207.0580)). + +Furthermore, the outputs are scaled by a factor of `1/(1-p)` during training. This allows the +`input` to be simply forwarded as-is during evaluation. 
+ +In this example, we demonstrate how the call to [forward](module.md#output-forwardinput) samples +different `outputs` to dropout (the zeros) given the same `input`: +```lua +module = nn.Dropout() + +> x=torch.Tensor{{1,2,3,4},{5,6,7,8}} + +> =module:forward(x) + 2 0 0 8 + 10 0 14 0 +[torch.DoubleTensor of dimension 2x4] + +> =module:forward(x) + 0 0 6 0 + 10 0 0 0 +[torch.DoubleTensor of dimension 2x4] + +``` + +[Backward](module.md#gradinput-backwardinput-gradoutput) drops out the gradients at the same location: +```lua +> =module:forward(x) + 0 4 0 0 + 10 12 0 16 +[torch.DoubleTensor of dimension 2x4] + +> =module:backward(x,x:clone():fill(1)) + 0 2 0 0 + 2 2 0 2 +[torch.DoubleTensor of dimension 2x4] + +``` +In both cases the `gradOutput` and `input` are scaled by `1/(1-p)`, which in this case is `2`. + +During [evaluation](module.md#evaluate), `Dropout` does nothing more than +forward the input such that all elements of the input are considered. +```lua +> module:evaluate() + +> module:forward(x) + 1 2 3 4 + 5 6 7 8 +[torch.DoubleTensor of dimension 2x4] + +``` + +We can return to training our model by first calling [Module:training()](module.md#training): +```lua +> module:training() + +> return module:forward(x) + 2 4 6 0 + 0 0 0 16 +[torch.DoubleTensor of dimension 2x4] + +``` + +When used, `Dropout` should normally be applied to the input of parameterized +[Modules](module.md#nn.Module) like [Linear](#nn.Linear) +or [SpatialConvolution](convolution.md#nn.SpatialConvolution). +A `p` of `0.5` (the default) is usually okay for hidden layers. +`Dropout` can sometimes be used successfully on the dataset inputs with a `p` around `0.2`. +It sometimes works best following [Transfer](transfer.md) Modules +like [ReLU](transfer.md#nn.ReLU). All this depends a great deal on the dataset so it's up +to the user to try different combinations. 
+ + <a name="nn.Abs"/> ## Abs ## diff --git a/doc/table.md b/doc/table.md index 4117117..60b6dea 100644 --- a/doc/table.md +++ b/doc/table.md @@ -9,6 +9,7 @@ This allows one to build very rich architectures: * Table Conversion Modules convert between tables and Tensors: * [SplitTable](#nn.SplitTable) : splits a Tensor into a table of Tensors; * [JoinTable](#nn.JoinTable) : joins a table of Tensors into a Tensor; + * [ElementTable](#nn.ElementTable) : retrieve one element from a table; * Pair Modules compute a measure like distance or similarity from a pair (table) of input Tensors : * [PairwiseDistance](#nn.PairwiseDistance) : outputs the `p`-norm. distance between inputs; * [DotProduct](#nn.DotProduct) : outputs the dot product (similarity) between inputs; @@ -375,6 +376,100 @@ for i=1,100 do -- A few steps of training such a network.. end ``` +<a name="nn.ElementTable"/> +## ElementTable ## + +`module` = `ElementTable(index)` + +Creates a module that takes a Table as input and outputs the element at index `index`. +This can be either a Table or a [Tensor](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor). + +The gradients of the non-`index` elements are zeroed Tensors of the same size. This is true regardless of the +depth of the encapsulated Tensor as the function used internally to do so is recursive. 
+ +Example 1: +```lua +> input = {torch.randn(2,3), torch.randn(2,1)} + [0.0002s] +> =nn.ElementTable(1):forward(input) +-0.3060 0.1398 0.2707 + 0.0576 1.5455 0.0610 +[torch.DoubleTensor of dimension 2x3] + + [0.0002s] +> =nn.ElementTable(2):forward(input) + 2.3080 +-0.2955 +[torch.DoubleTensor of dimension 2x1] + +> =unpack(nn.ElementTable(1):backward(input, torch.randn(2,3))) +-0.4891 -0.3495 -0.3182 +-2.0999 0.7381 -0.5312 +[torch.DoubleTensor of dimension 2x3] + +0 +0 +[torch.DoubleTensor of dimension 2x1] + +``` + +Example 2: +```lua +> input = {torch.randn(2,3), {torch.randn(2,1), {torch.randn(2,2)}}} + +> =nn.ElementTable(2):forward(input) +{ + 1 : DoubleTensor - size: 2x1 + 2 : + { + 1 : DoubleTensor - size: 2x2 + } +} + +> =unpack(nn.ElementTable(2):backward(input, {torch.randn(2,1), {torch.randn(2,2)}})) +0 0 0 +0 0 0 +[torch.DoubleTensor of dimension 2x3] + +{ + 1 : DoubleTensor - size: 2x1 + 2 : + { + 1 : DoubleTensor - size: 2x2 + } +} + +> gradInput = nn.ElementTable(1):backward(input, torch.randn(2,3)) + +> =gradInput +{ + 1 : DoubleTensor - size: 2x3 + 2 : + { + 1 : DoubleTensor - size: 2x1 + 2 : + { + 1 : DoubleTensor - size: 2x2 + } + } +} + +> =gradInput[1] +-0.3400 -0.0404 1.1885 + 1.2865 0.4107 0.6506 +[torch.DoubleTensor of dimension 2x3] + +> gradInput[2][1] +0 +0 +[torch.DoubleTensor of dimension 2x1] + +> gradInput[2][2][1] +0 0 +0 0 +[torch.DoubleTensor of dimension 2x2] + +``` <a name="nn.PairwiseDistance"/> ## PairwiseDistance ## diff --git a/doc/transfer.md b/doc/transfer.md index 0a5334a..0a47d7c 100755 --- a/doc/transfer.md +++ b/doc/transfer.md @@ -231,6 +231,23 @@ gnuplot.grid(true) ``` ![](image/tanh.png) +<a name="nn.ReLU"/> +## ReLU ## + +Applies the rectified linear unit (`ReLU`) function element-wise to the input Tensor, +thus outputting a Tensor of the same dimension. 
+ +```lua +ii=torch.linspace(-3,3) +m=nn.ReLU() +oo=m:forward(ii) +go=torch.ones(100) +gi=m:backward(ii,go) +gnuplot.plot({'f(x)',ii,oo,'+-'},{'df/dx',ii,gi,'+-'}) +gnuplot.grid(true) +``` +![](image/relu.png) + <a name="nn.AddConstant"/> ## AddConstant ## @@ -26,6 +26,7 @@ include('Mul.lua') include('MulConstant.lua') include('Add.lua') include('AddConstant.lua') +include('Dropout.lua') include('CAddTable.lua') include('CDivTable.lua') @@ -57,6 +58,7 @@ include('Sqrt.lua') include('HardShrink.lua') include('SoftShrink.lua') include('Threshold.lua') +include('ReLU.lua') include('LookupTable.lua') include('SpatialConvolution.lua') @@ -85,6 +87,7 @@ include('ParallelTable.lua') include('ConcatTable.lua') include('SplitTable.lua') include('JoinTable.lua') +include('ElementTable.lua') include('CriterionTable.lua') include('Identity.lua') diff --git a/test/test.lua b/test/test.lua index 9ecc923..73426fb 100644 --- a/test/test.lua +++ b/test/test.lua @@ -60,6 +60,36 @@ function nntest.CMul() mytester:asserteq(berr, 0, torch.typename(module) .. 
' - i/o backward err ') end +function nntest.Dropout() + local p = 0.2 -- probability of dropping out a neuron + local input = torch.Tensor(1000):fill((1-p)) + local module = nn.Dropout(p) + -- version 2 (default): the kept outputs are scaled by 1/(1-p) during training + local output = module:forward(input) + mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output') + local gradInput = module:backward(input, input) + mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput') + -- version 1 (old nnx version): the outputs are not rescaled during training + local input = input:fill(1) + local module = nn.Dropout(p,true) + local output = module:forward(input) + mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output') + local gradInput = module:backward(input, input) + mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput') +end + +function nntest.ReLU() + local input = torch.randn(3,4) + local gradOutput = torch.randn(3,4) + local module = nn.ReLU() + local output = module:forward(input) + local output2 = input:clone():gt(input, 0):cmul(input) + mytester:assertTensorEq(output, output2, 0.000001, 'ReLU output') + local gradInput = module:backward(input, gradOutput) + local gradInput2 = input:clone():gt(input, 0):cmul(gradOutput) + mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'ReLU gradInput') +end + function nntest.Exp() local ini = math.random(10,20) local inj = math.random(10,20) @@ -1869,6 +1899,49 @@ function nntest.SplitTable() end end +function nntest.ElementTable() + local input = { + torch.rand(3,4,5), torch.rand(3,4,5), + {torch.rand(3,4,5)}, + {torch.rand(3,4,5), {torch.rand(3,4,5)}} + } + local gradOutputs = { + torch.rand(3,4,5), torch.rand(3,4,5), + {torch.rand(3,4,5)}, + {torch.rand(3,4,5), {torch.rand(3,4,5)}} + } + local zeros = { + torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(), + {torch.Tensor(3,4,5):zero()}, + {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}} + } + local function equal(t1, t2, msg) + if (torch.type(t1) == "table") then + for k, v in pairs(t2) do + 
equal(t1[k], v, msg) + end + else + mytester:assertTensorEq(t1, t2, 0.00001, msg) + end + end + local nonIdx = {2,3,4,1} + local module + for idx = 1,#input do + module = nn.ElementTable(idx) + local output = module:forward(input) + equal(output, input[idx], "output dimension " .. idx) + local gradInput = module:backward(input, gradOutputs[idx]) + equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx) + equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) + end + module:float() + local idx = #input + local output = module:forward(input) + equal(output, input[idx], "type output") + local gradInput = module:backward(input, gradOutputs[idx]) + equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx) + equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) +end function nntest.View() local input = torch.rand(10) |