diff options
Diffstat (limited to 'BatchNormalization.lua')
-rw-r--r-- | BatchNormalization.lua | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/BatchNormalization.lua b/BatchNormalization.lua index 8dfc576..9e57998 100644 --- a/BatchNormalization.lua +++ b/BatchNormalization.lua @@ -130,7 +130,7 @@ function BN:updateOutput(input) self.running_var:cdata(), self.save_mean:cdata(), self.save_std:cdata(), - self.train, + self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1 self.momentum, self.eps) @@ -162,7 +162,7 @@ local function backward(self, input, gradOutput, scale, gradInput, gradWeight, g self.running_var:cdata(), self.save_mean:cdata(), self.save_std:cdata(), - self.train, + self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1 scale, self.eps) |