Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Nicholas Leonard <nleonard@twitter.com> 2017-08-03 22:14:14 +0300
committer Nicholas Leonard <nleonard@twitter.com> 2017-08-03 22:14:14 +0300
commit 919eecae9ff58e78c45817d836d85e0372e2fea1 (patch)
tree 62cecb37a4ee210b903bb8344a3b5890f46daa0f
parent 3e3f1191293f6d6bc2f5b413fd428b80f62e570b (diff)
BN supports batchsize=1
-rw-r--r-- BatchNormalization.lua 4
-rwxr-xr-x test.lua 5
2 files changed, 5 insertions, 4 deletions
diff --git a/BatchNormalization.lua b/BatchNormalization.lua
index 8dfc576..9e57998 100644
--- a/BatchNormalization.lua
+++ b/BatchNormalization.lua
@@ -130,7 +130,7 @@ function BN:updateOutput(input)
self.running_var:cdata(),
self.save_mean:cdata(),
self.save_std:cdata(),
- self.train,
+ self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1
self.momentum,
self.eps)
@@ -162,7 +162,7 @@ local function backward(self, input, gradOutput, scale, gradInput, gradWeight, g
self.running_var:cdata(),
self.save_mean:cdata(),
self.save_std:cdata(),
- self.train,
+ self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1
scale,
self.eps)
diff --git a/test.lua b/test.lua
index 35852fa..18117c0 100755
--- a/test.lua
+++ b/test.lua
@@ -7609,9 +7609,9 @@ function nntest.Replicate()
mytester:assertTensorEq(vOutput2, expected2, precision, 'Wrong tiling of data when replicating batch vector.')
end
-local function testBatchNormalization(moduleName, dim, k)
+local function testBatchNormalization(moduleName, dim, k, batchsize)
local planes = torch.random(1,k)
- local size = { torch.random(2, k), planes }
+ local size = { batchsize or torch.random(2, k), planes }
for i=1,dim do
table.insert(size, torch.random(1,k))
end
@@ -7670,6 +7670,7 @@ end
function nntest.BatchNormalization()
testBatchNormalization('BatchNormalization', 0, 20)
+ testBatchNormalization('BatchNormalization', 0, 20, 1) -- test batchsize=1
end
function nntest.SpatialBatchNormalization()