diff options
author | Nicholas LĂ©onard <nick@nikopia.org> | 2017-08-03 23:46:44 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-08-03 23:46:44 +0300 |
commit | bae729acce1930aa46be5c6ca0d7272f7eba406e (patch) | |
tree | 62cecb37a4ee210b903bb8344a3b5890f46daa0f | |
parent | 3e3f1191293f6d6bc2f5b413fd428b80f62e570b (diff) | |
parent | 919eecae9ff58e78c45817d836d85e0372e2fea1 (diff) |
Merge pull request #1271 from nicholas-leonard/BN-batchsize1
BN supports batchsize=1
-rw-r--r-- | BatchNormalization.lua | 4 | ||||
-rwxr-xr-x | test.lua | 5 |
2 files changed, 5 insertions, 4 deletions
diff --git a/BatchNormalization.lua b/BatchNormalization.lua index 8dfc576..9e57998 100644 --- a/BatchNormalization.lua +++ b/BatchNormalization.lua @@ -130,7 +130,7 @@ function BN:updateOutput(input) self.running_var:cdata(), self.save_mean:cdata(), self.save_std:cdata(), - self.train, + self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1 self.momentum, self.eps) @@ -162,7 +162,7 @@ local function backward(self, input, gradOutput, scale, gradInput, gradWeight, g self.running_var:cdata(), self.save_mean:cdata(), self.save_std:cdata(), - self.train, + self.train and (input:size(1) > 1), -- don't update running_[var,mean] when batchsize = 1 scale, self.eps) @@ -7609,9 +7609,9 @@ function nntest.Replicate() mytester:assertTensorEq(vOutput2, expected2, precision, 'Wrong tiling of data when replicating batch vector.') end -local function testBatchNormalization(moduleName, dim, k) +local function testBatchNormalization(moduleName, dim, k, batchsize) local planes = torch.random(1,k) - local size = { torch.random(2, k), planes } + local size = { batchsize or torch.random(2, k), planes } for i=1,dim do table.insert(size, torch.random(1,k)) end @@ -7670,6 +7670,7 @@ end function nntest.BatchNormalization() testBatchNormalization('BatchNormalization', 0, 20) + testBatchNormalization('BatchNormalization', 0, 20, 1) -- test batchsize=1 end function nntest.SpatialBatchNormalization() |