Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nleonard@twitter.com>2017-08-03 22:14:14 +0300
committerNicholas Leonard <nleonard@twitter.com>2017-08-03 22:14:14 +0300
commit919eecae9ff58e78c45817d836d85e0372e2fea1 (patch)
tree62cecb37a4ee210b903bb8344a3b5890f46daa0f /BatchNormalization.lua
parent3e3f1191293f6d6bc2f5b413fd428b80f62e570b (diff)
BN supports batchsize=1
Diffstat (limited to 'BatchNormalization.lua')
-rw-r--r--BatchNormalization.lua4
1 files changed, 2 insertions, 2 deletions
diff --git a/BatchNormalization.lua b/BatchNormalization.lua
index 8dfc576..9e57998 100644
--- a/BatchNormalization.lua
+++ b/BatchNormalization.lua
@@ -130,7 +130,7 @@ function BN:updateOutput(input)
self.running_var:cdata(),
self.save_mean:cdata(),
self.save_std:cdata(),
- self.train,
+ self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1
self.momentum,
self.eps)
@@ -162,7 +162,7 @@ local function backward(self, input, gradOutput, scale, gradInput, gradWeight, g
self.running_var:cdata(),
self.save_mean:cdata(),
self.save_std:cdata(),
- self.train,
+ self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1
scale,
self.eps)