Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas LĂ©onard <nick@nikopia.org>2017-08-03 23:46:50 +0300
committerGitHub <noreply@github.com>2017-08-03 23:46:50 +0300
commit90afcbf93b629b74cbb3bd76b12a0a8c389195e9 (patch)
tree2b334d10fe361570d5d49cac0fe8ff9c74943c4e
parente9ef2d5281dec554724b816b520413c437fb1772 (diff)
parentab0ee1fe41b9843dc2469541b3f0ff856c12e547 (diff)
Merge pull request #480 from nicholas-leonard/BN-batchsize1
BN supports batchsize=1
-rw-r--r--test.lua5
1 files changed, 3 insertions, 2 deletions
diff --git a/test.lua b/test.lua
index fb65bd9..02b63d4 100644
--- a/test.lua
+++ b/test.lua
@@ -978,9 +978,9 @@ local function BatchNormalization_backward(moduleName, mode, inputSize, backward
end
end
-local function testBatchNormalization(name, dim, k)
+local function testBatchNormalization(name, dim, k, batchsize)
local function inputSize()
- local inputSize = { torch.random(2,32), torch.random(1, k) }
+ local inputSize = { batchsize or torch.random(2,32), torch.random(1, k) }
for i=1,dim do
table.insert(inputSize, torch.random(1,k))
end
@@ -1005,6 +1005,7 @@ end
function cunntest.BatchNormalization()
testBatchNormalization('BatchNormalization', 0, 128)
+ testBatchNormalization('BatchNormalization', 0, 128, 1) -- test batchsize=1
end
function cunntest.SpatialBatchNormalization()