From 919eecae9ff58e78c45817d836d85e0372e2fea1 Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Thu, 3 Aug 2017 15:14:14 -0400 Subject: BN supports batchsize=1 --- BatchNormalization.lua | 4 ++-- test.lua | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/BatchNormalization.lua b/BatchNormalization.lua index 8dfc576..9e57998 100644 --- a/BatchNormalization.lua +++ b/BatchNormalization.lua @@ -130,7 +130,7 @@ function BN:updateOutput(input) self.running_var:cdata(), self.save_mean:cdata(), self.save_std:cdata(), - self.train, + self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1 self.momentum, self.eps) @@ -162,7 +162,7 @@ local function backward(self, input, gradOutput, scale, gradInput, gradWeight, g self.running_var:cdata(), self.save_mean:cdata(), self.save_std:cdata(), - self.train, + self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1 scale, self.eps) diff --git a/test.lua b/test.lua index 35852fa..18117c0 100755 --- a/test.lua +++ b/test.lua @@ -7609,9 +7609,9 @@ function nntest.Replicate() mytester:assertTensorEq(vOutput2, expected2, precision, 'Wrong tiling of data when replicating batch vector.') end -local function testBatchNormalization(moduleName, dim, k) +local function testBatchNormalization(moduleName, dim, k, batchsize) local planes = torch.random(1,k) - local size = { torch.random(2, k), planes } + local size = { batchsize or torch.random(2, k), planes } for i=1,dim do table.insert(size, torch.random(1,k)) end @@ -7670,6 +7670,7 @@ end function nntest.BatchNormalization() testBatchNormalization('BatchNormalization', 0, 20) + testBatchNormalization('BatchNormalization', 0, 20, 1) -- test batchsize=1 end function nntest.SpatialBatchNormalization() -- cgit v1.2.3