diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-08-03 22:22:55 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-08-03 22:22:55 +0300 |
commit | ab0ee1fe41b9843dc2469541b3f0ff856c12e547 (patch) | |
tree | 2b334d10fe361570d5d49cac0fe8ff9c74943c4e | |
parent | e9ef2d5281dec554724b816b520413c437fb1772 (diff) |
BN supports batchsize=1
-rw-r--r-- | test.lua | 5 |
1 files changed, 3 insertions, 2 deletions
@@ -978,9 +978,9 @@ local function BatchNormalization_backward(moduleName, mode, inputSize, backward end end -local function testBatchNormalization(name, dim, k) +local function testBatchNormalization(name, dim, k, batchsize) local function inputSize() - local inputSize = { torch.random(2,32), torch.random(1, k) } + local inputSize = { batchsize or torch.random(2,32), torch.random(1, k) } for i=1,dim do table.insert(inputSize, torch.random(1,k)) end @@ -1005,6 +1005,7 @@ end function cunntest.BatchNormalization() testBatchNormalization('BatchNormalization', 0, 128) + testBatchNormalization('BatchNormalization', 0, 128, 1) -- test batchsize=1 end function cunntest.SpatialBatchNormalization() |