Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'BatchNormalization.lua')
-rw-r--r--BatchNormalization.lua4
1 files changed, 2 insertions, 2 deletions
diff --git a/BatchNormalization.lua b/BatchNormalization.lua
index 8dfc576..9e57998 100644
--- a/BatchNormalization.lua
+++ b/BatchNormalization.lua
@@ -130,7 +130,7 @@ function BN:updateOutput(input)
self.running_var:cdata(),
self.save_mean:cdata(),
self.save_std:cdata(),
- self.train,
+ self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1
self.momentum,
self.eps)
@@ -162,7 +162,7 @@ local function backward(self, input, gradOutput, scale, gradInput, gradWeight, g
self.running_var:cdata(),
self.save_mean:cdata(),
self.save_std:cdata(),
- self.train,
+ self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1
scale,
self.eps)