
github.com/torch/nn.git
author    Sergey Zagoruyko <zagoruyko2@gmail.com>  2015-09-14 02:06:00 +0300
committer Sergey Zagoruyko <zagoruyko2@gmail.com>  2015-09-14 02:06:00 +0300
commit    0fb47b0b16c6180335b520661967cd6ea0656da7 (patch)
tree      4e9b7661870c069501ba3fc6a0ebaa90949a7a70 /Dropout.lua
parent    33196ac2e417cf722c959727a959759d6576324a (diff)

    in-place dropout

Diffstat (limited to 'Dropout.lua')
 Dropout.lua | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)
diff --git a/Dropout.lua b/Dropout.lua
index 66eda21..80a42af 100644
--- a/Dropout.lua
+++ b/Dropout.lua
@@ -1,9 +1,10 @@
 local Dropout, Parent = torch.class('nn.Dropout', 'nn.Module')
 
-function Dropout:__init(p,v1)
+function Dropout:__init(p,v1,inplace)
    Parent.__init(self)
    self.p = p or 0.5
    self.train = true
+   self.inplace = inplace
    -- version 2 scales output during training instead of evaluation
    self.v2 = not v1
    if self.p >= 1 or self.p < 0 then
@@ -13,7 +14,11 @@ function Dropout:__init(p,v1)
 end
 
 function Dropout:updateOutput(input)
-   self.output:resizeAs(input):copy(input)
+   if self.inplace then
+      self.output = input
+   else
+      self.output:resizeAs(input):copy(input)
+   end
    if self.train then
       self.noise:resizeAs(input)
       self.noise:bernoulli(1-self.p)
@@ -29,7 +34,11 @@ end
 function Dropout:updateGradInput(input, gradOutput)
    if self.train then
-      self.gradInput:resizeAs(gradOutput):copy(gradOutput)
+      if self.inplace then
+         self.gradInput = gradOutput
+      else
+         self.gradInput:resizeAs(gradOutput):copy(gradOutput)
+      end
       self.gradInput:cmul(self.noise) -- simply mask the gradients with the noise vector
    else
       error('backprop only defined while training')
    end
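For anyone trying the new flag, a minimal usage sketch (not part of the commit), based only on the __init(p, v1, inplace) signature introduced above; the probability and tensor size are arbitrary, and v1 is left false to keep the default version-2 scaling:

require 'nn'

-- default dropout: output is a fresh buffer, the input is left untouched
local drop = nn.Dropout(0.5)

-- in-place dropout: forward aliases self.output to the input tensor and
-- backward aliases self.gradInput to gradOutput, skipping both copies
local dropInplace = nn.Dropout(0.5, false, true)

local x = torch.rand(4, 4)
local y = dropInplace:forward(x)

-- same underlying tensor: the mask was applied directly to x
print(torch.pointer(y) == torch.pointer(x)) --> true

The trade-off is that in-place mode overwrites its input (and gradOutput during backprop), so it is only safe when no other module still needs those tensors.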