github.com/torch/nn.git
author     Jonathan Tompson <jonathantompson@gmail.com>   2015-03-27 00:39:43 +0300
committer  Jonathan Tompson <jonathantompson@gmail.com>   2015-03-27 00:39:43 +0300
commit     91b494dd4d53ace615651ba2574f90ffc7d4a2df (patch)
tree       40566777bdb0734c7a278fa50723e41a62815760
parent     6577b5535eb136e4668dbd1fe509c79df19678e1 (diff)
Added SpatialDropout + doc + test.
-rwxr-xr-x               SpatialDropout.lua    43
-rwxr-xr-x [-rw-r--r--]  doc/simple.md         11
-rw-r--r--               init.lua               1
-rw-r--r--               test.lua              30
4 files changed, 84 insertions(+), 1 deletion(-)
diff --git a/SpatialDropout.lua b/SpatialDropout.lua
new file mode 100755
index 0000000..6736783
--- /dev/null
+++ b/SpatialDropout.lua
@@ -0,0 +1,43 @@
+local SpatialDropout, Parent = torch.class('nn.SpatialDropout', 'nn.Module')
+
+function SpatialDropout:__init(p)
+   Parent.__init(self)
+   self.p = p or 0.5
+   self.train = true
+   self.noise = torch.Tensor()
+end
+
+function SpatialDropout:updateOutput(input)
+   self.output:resizeAs(input):copy(input)
+   if self.train then
+      if input:dim() == 4 then
+         self.noise:resize(input:size(1), input:size(2), 1, 1)
+      elseif input:dim() == 3 then
+         self.noise:resize(input:size(1), 1, 1)
+      else
+         error('Input must be 4D (nbatch, nfeat, h, w) or 3D (nfeat, h, w)')
+      end
+      self.noise:bernoulli(1-self.p)
+      -- We expand the random dropouts to the entire feature map because the
+      -- features are likely correlated across the map and so the dropout
+      -- should also be correlated.
+      self.output:cmul(torch.expandAs(self.noise, input))
+   else
+      self.output:mul(1-self.p)
+   end
+   return self.output
+end
+
+function SpatialDropout:updateGradInput(input, gradOutput)
+   if self.train then
+      self.gradInput:resizeAs(gradOutput):copy(gradOutput)
+      self.gradInput:cmul(torch.expandAs(self.noise, input)) -- mask the gradients with the same noise used in the forward pass
+   else
+      error('backprop only defined while training')
+   end
+   return self.gradInput
+end
+
+function SpatialDropout:setp(p)
+   self.p = p
+end
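
For orientation (not part of the commit), a minimal usage sketch of the module added above; it assumes `torch` and `nn` are installed and that `SpatialDropout.lua` is included via `init.lua` as in the hunks below:

```lua
require 'nn'

-- Non-batch input: (nfeat x h x w). One Bernoulli(1-p) trial is drawn per
-- feature map and expanded over the whole h x w plane.
local input = torch.rand(16, 8, 8)
local module = nn.SpatialDropout(0.5)

module.train = true                    -- training mode: whole feature maps are zeroed
local output = module:forward(input)
local gradInput = module:backward(input, torch.rand(16, 8, 8))

module.train = false                   -- evaluation mode: activations are scaled by (1 - p)
local evalOutput = module:forward(input)
```

Note that backward is only defined while `module.train` is true, matching the error raised in `updateGradInput`.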
diff --git a/doc/simple.md b/doc/simple.md
index 7d806a6..a35f852 100644..100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -33,6 +33,7 @@ and providing affine transformations :
* [BatchNormalization](#nn.BatchNormalization) - mean/std normalization over the mini-batch inputs (with an optional affine transform) ;
* [Identity](#nn.Identity) : forward input as-is to output (useful with [ParallelTable](table.md#nn.ParallelTable));
* [Dropout](#nn.Dropout) : masks parts of the `input` using binary samples from a [bernoulli](http://en.wikipedia.org/wiki/Bernoulli_distribution) distribution ;
+ * [SpatialDropout](#nn.SpatialDropout) : same as [Dropout](#nn.Dropout) but for spatial inputs where adjacent pixels are strongly correlated ;
* [Padding](#nn.Padding) : adds padding to a dimension ;
* [L1Penalty](#nn.L1Penalty) : adds an L1 penalty to an input (for sparsity);
@@ -196,6 +197,16 @@ It sometimes works best following [Transfer](transfer.md) Modules
like [ReLU](transfer.md#nn.ReLU). All this depends a great deal on the dataset, so it's up
to the user to try different combinations.
+<a name="nn.SpatialDropout"/>
+## SpatialDropout ##
+
+`module` = `nn.SpatialDropout(p)`
+
+This version performs the same function as `nn.Dropout`; however, it assumes the two right-most dimensions of the input are spatial, performs one Bernoulli trial per output feature during training, and extends this dropout value across the entire feature map.
+
+As described in the paper "Efficient Object Localization Using Convolutional Networks" (http://arxiv.org/abs/1411.4280), if adjacent pixels within feature maps are strongly correlated (as is normally the case in early convolution layers), then i.i.d. dropout will not regularize the activations and will instead just result in an effective learning rate decrease. In this case, `nn.SpatialDropout` helps promote independence between feature maps and should be used instead.
+
+`nn.SpatialDropout` accepts 3D or 4D inputs. If the input is 3D, a layout of (features x height x width) is assumed; if it is 4D, (batch x features x height x width) is assumed.
<a name="nn.Abs"/>
## Abs ##
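
To make the 3D/4D layout convention in the SpatialDropout documentation above concrete, a short illustrative sketch (shapes chosen arbitrarily; not part of the commit):

```lua
require 'nn'

local drop = nn.SpatialDropout(0.4)

-- 3D input: features x height x width (non-batch mode)
local y3 = drop:forward(torch.rand(64, 32, 32))

-- 4D input: batch x features x height x width (batch mode)
local y4 = drop:forward(torch.rand(8, 64, 32, 32))
```

In both cases a single Bernoulli sample per feature map decides whether that entire 32x32 plane is kept or zeroed during training.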
diff --git a/init.lua b/init.lua
index f344b54..65039fc 100644
--- a/init.lua
+++ b/init.lua
@@ -32,6 +32,7 @@ include('MulConstant.lua')
include('Add.lua')
include('AddConstant.lua')
include('Dropout.lua')
+include('SpatialDropout.lua')
include('CAddTable.lua')
include('CDivTable.lua')
diff --git a/test.lua b/test.lua
index a15cd61..27a1747 100644
--- a/test.lua
+++ b/test.lua
@@ -37,7 +37,6 @@ for test_name, component in pairs(tostringTestModules) do
end
end
-
function nntest.Add()
local inj_vals = {math.random(3,5), 1} -- Also test the inj = 1 spatial case
local ini = math.random(3,5)
@@ -172,6 +171,35 @@ function nntest.Dropout()
mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput')
end
+function nntest.SpatialDropout()
+   local p = 0.2 -- probability of dropping out a feature map
+   local w = math.random(1,5)
+   local h = math.random(1,5)
+   local nfeats = 1000
+   local input = torch.Tensor(nfeats, w, h):fill(1)
+   local module = nn.SpatialDropout(p)
+   module.train = true
+   local output = module:forward(input)
+   mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output')
+   local gradInput = module:backward(input, input)
+   mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput')
+end
+
+function nntest.SpatialDropoutBatch()
+   local p = 0.2 -- probability of dropping out a feature map
+   local bsz = math.random(1,5)
+   local w = math.random(1,5)
+   local h = math.random(1,5)
+   local nfeats = 1000
+   local input = torch.Tensor(bsz, nfeats, w, h):fill(1)
+   local module = nn.SpatialDropout(p)
+   module.train = true
+   local output = module:forward(input)
+   mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output')
+   local gradInput = module:backward(input, input)
+   mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput')
+end
+
function nntest.ReLU()
local input = torch.randn(3,4)
local gradOutput = torch.randn(3,4)