diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-08-03 22:14:14 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-08-03 22:14:14 +0300 |
commit | 919eecae9ff58e78c45817d836d85e0372e2fea1 (patch) | |
tree | 62cecb37a4ee210b903bb8344a3b5890f46daa0f /BatchNormalization.lua | |
parent | 3e3f1191293f6d6bc2f5b413fd428b80f62e570b (diff) |
BN supports batchsize=1
Diffstat (limited to 'BatchNormalization.lua')
-rw-r--r-- | BatchNormalization.lua | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/BatchNormalization.lua b/BatchNormalization.lua index 8dfc576..9e57998 100644 --- a/BatchNormalization.lua +++ b/BatchNormalization.lua @@ -130,7 +130,7 @@ function BN:updateOutput(input) self.running_var:cdata(), self.save_mean:cdata(), self.save_std:cdata(), - self.train, + self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1 self.momentum, self.eps) @@ -162,7 +162,7 @@ local function backward(self, input, gradOutput, scale, gradInput, gradWeight, g self.running_var:cdata(), self.save_mean:cdata(), self.save_std:cdata(), - self.train, + self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1 scale, self.eps) |