Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Kuen <xternalz@users.noreply.github.com>2016-09-02 21:42:03 +0300
committerSoumith Chintala <soumith@gmail.com>2016-09-02 21:42:03 +0300
commit40d3a23d006e89c97a29a6714efbde9e38c89ad4 (patch)
tree0951eb3f27e893c335957563a2b3f2334fd39a8b /SpatialDropout.lua
parent3cb23656d2cea404f17d9dd43cdc78ce5d3ec1a8 (diff)
Stochastic Inference for Dropout (#936)
* Stochastic Inference for Dropout
Diffstat (limited to 'SpatialDropout.lua')
-rw-r--r--SpatialDropout.lua5
1 files changed, 3 insertions, 2 deletions
diff --git a/SpatialDropout.lua b/SpatialDropout.lua
index 99cd0fc..4320061 100644
--- a/SpatialDropout.lua
+++ b/SpatialDropout.lua
@@ -1,15 +1,16 @@
local SpatialDropout, Parent = torch.class('nn.SpatialDropout', 'nn.Module')
-function SpatialDropout:__init(p)
+function SpatialDropout:__init(p,stochasticInference)
Parent.__init(self)
self.p = p or 0.5
self.train = true
+ self.stochastic_inference = stochasticInference or false
self.noise = torch.Tensor()
end
function SpatialDropout:updateOutput(input)
self.output:resizeAs(input):copy(input)
- if self.train then
+ if self.train or self.stochastic_inference then
if input:dim() == 4 then
self.noise:resize(input:size(1), input:size(2), 1, 1)
elseif input:dim() == 3 then