diff options
author | Jason Kuen <xternalz@users.noreply.github.com> | 2016-09-02 21:42:03 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-09-02 21:42:03 +0300 |
commit | 40d3a23d006e89c97a29a6714efbde9e38c89ad4 (patch) | |
tree | 0951eb3f27e893c335957563a2b3f2334fd39a8b /SpatialDropout.lua | |
parent | 3cb23656d2cea404f17d9dd43cdc78ce5d3ec1a8 (diff) |
Stochastic Inference for Dropout (#936)
* Stochastic Inference for Dropout
Diffstat (limited to 'SpatialDropout.lua')
-rw-r--r-- | SpatialDropout.lua | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/SpatialDropout.lua b/SpatialDropout.lua index 99cd0fc..4320061 100644 --- a/SpatialDropout.lua +++ b/SpatialDropout.lua @@ -1,15 +1,16 @@ local SpatialDropout, Parent = torch.class('nn.SpatialDropout', 'nn.Module') -function SpatialDropout:__init(p) +function SpatialDropout:__init(p,stochasticInference) Parent.__init(self) self.p = p or 0.5 self.train = true + self.stochastic_inference = stochasticInference or false self.noise = torch.Tensor() end function SpatialDropout:updateOutput(input) self.output:resizeAs(input):copy(input) - if self.train then + if self.train or self.stochastic_inference then if input:dim() == 4 then self.noise:resize(input:size(1), input:size(2), 1, 1) elseif input:dim() == 3 then |