diff options
author | ivpopov <ivpopov@google.com> | 2016-01-08 15:04:40 +0300 |
---|---|---|
committer | ivpopov <ivpopov@google.com> | 2016-01-08 15:04:40 +0300 |
commit | ff242ba2dd96794bd378c398915a2a30c10b174d (patch) | |
tree | 7a996f9d15c30359c6392c2a1be68e62df78c0da /Dropout.lua | |
parent | cd5924bcba3a9cb53be47e8461bcff5b312a6494 (diff) |
Update Dropout.lua
Diffstat (limited to 'Dropout.lua')
-rw-r--r-- | Dropout.lua | 46 |
1 files changed, 21 insertions, 25 deletions
diff --git a/Dropout.lua b/Dropout.lua index 8d11c45..920c981 100644 --- a/Dropout.lua +++ b/Dropout.lua @@ -14,42 +14,38 @@ function Dropout:__init(p,v1,inplace) end function Dropout:updateOutput(input) + if self.inplace then + self.output = input + else + self.output:resizeAs(input):copy(input) + end if self.p > 0 then - if self.inplace then - self.output = input - else - self.output:resizeAs(input):copy(input) - end if self.train then - self.noise:resizeAs(input) - self.noise:bernoulli(1-self.p) - if self.v2 then - self.noise:div(1-self.p) - end - self.output:cmul(self.noise) + self.noise:resizeAs(input) + self.noise:bernoulli(1-self.p) + if self.v2 then + self.noise:div(1-self.p) + end + self.output:cmul(self.noise) elseif not self.v2 then - self.output:mul(1-self.p) + self.output:mul(1-self.p) end - else - self.output = input end return self.output end function Dropout:updateGradInput(input, gradOutput) - if self.p > 0 then - if self.train then - if self.inplace then - self.gradInput = gradOutput - else - self.gradInput:resizeAs(gradOutput):copy(gradOutput) - end - self.gradInput:cmul(self.noise) -- simply mask the gradients with the noise vector + if self.train then + if self.inplace then + self.gradInput = gradOutput else - error('backprop only defined while training') + self.gradInput:resizeAs(gradOutput):copy(gradOutput) + end + if self.p > 0 then + self.gradInput:cmul(self.noise) -- simply mask the gradients with the noise vector end else - self.gradInput = gradOutput + error('backprop only defined while training') end return self.gradInput end @@ -59,5 +55,5 @@ function Dropout:setp(p) end function Dropout:__tostring__() - return string.format('%s(%f)', torch.type(self), self.p) + return string.format('%s(%f)', torch.type(self), self.p) end |