Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    nicholas-leonard <nick@nikopia.org>  2014-07-10 00:23:09 +0400
committer nicholas-leonard <nick@nikopia.org>  2014-07-10 00:23:09 +0400
commit f08c7fcae24a0302043701e5e4c36a8557b431a9 (patch)
tree   895323a89309a83c8f20196fb16198916c606c40
parent cfb269369c9f1a1b523b2d9865df241868a601db (diff)
added Dropout version 2
-rw-r--r--  Dropout.lua    | 9
-rw-r--r--  test/test.lua  | 10
2 files changed, 16 insertions, 3 deletions
diff --git a/Dropout.lua b/Dropout.lua
index ac6c463..a92faf2 100644
--- a/Dropout.lua
+++ b/Dropout.lua
@@ -1,9 +1,11 @@
local Dropout, Parent = torch.class('nn.Dropout', 'nn.Module')
-function Dropout:__init(p)
+function Dropout:__init(p,v1)
Parent.__init(self)
self.p = p or 0.5
self.train = true
+ -- version 2 scales output during training instead of evaluation
+ self.v2 = not v1
if self.p >= 1 or self.p < 0 then
error('<Dropout> illegal percentage, must be 0 <= p < 1')
end
@@ -19,8 +21,11 @@ function Dropout:updateOutput(input)
self.noise:resizeAs(input)
self.fnoise:bernoulli(1-self.p)
self.noise:copy(self.fnoise)
+ if self.v2 then
+ self.noise:div(1-self.p)
+ end
self.output:cmul(self.noise)
- else
+ elseif not self.v2 then
self.output:mul(1-self.p)
end
return self.output
diff --git a/test/test.lua b/test/test.lua
index a3c816f..cc82e9e 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -62,8 +62,16 @@ end
function nntest.Dropout()
local p = 0.2 --prob of droping out a neuron
- local input = torch.Tensor(1000):fill(1)
+ local input = torch.Tensor(1000):fill((1-p))
local module = nn.Dropout(p)
+ -- version 2
+ local output = module:forward(input)
+ mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output')
+ local gradInput = module:backward(input, input)
+ mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput')
+ -- version 1 (old nnx version)
+ local input = input:fill(1)
+ local module = nn.Dropout(p,true)
local output = module:forward(input)
mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output')
local gradInput = module:backward(input, input)