From 83b5b6c6bbed0fb1a9457fc285d187341a040e90 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Wed, 24 Feb 2016 14:19:52 -0800 Subject: Add cudnn.BatchNormalization and cudnn.VolumetricBatchNormalization --- test/test.lua | 48 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 12 deletions(-) (limited to 'test') diff --git a/test/test.lua b/test/test.lua index 10ceabb..9449b88 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1134,16 +1134,11 @@ function cudnntest.SpatialLogSoftMax() mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward') end -function cudnntest.SpatialBatchNormalization() - -- batch - local h = math.random(5,10) - local w = math.random(5,10) - local bsz = math.random(1, 32) - local from = math.random(1, 32) - local input = torch.randn(bsz,from,h,w):cuda() - local gradOutput = torch.randn(bsz,from,h,w):cuda() - local cbn = cudnn.SpatialBatchNormalization(from, 1e-3):cuda() - local gbn = nn.SpatialBatchNormalization(from, 1e-3):cuda() +local function testBatchNormalization(moduleName, inputSize) + local input = torch.randn(table.unpack(inputSize)):cuda() + local gradOutput = torch.randn(table.unpack(inputSize)):cuda() + local cbn = cudnn[moduleName](inputSize[2], 1e-3):cuda() + local gbn = nn[moduleName](inputSize[2], 1e-3):cuda() cbn.weight:copy(gbn.weight) cbn.bias:copy(gbn.bias) mytester:asserteq(cbn.running_mean:mean(), 0, 'error on BN running_mean init') @@ -1161,6 +1156,35 @@ function cudnntest.SpatialBatchNormalization() precision_backward, 'error in batch normalization (backward) ') end +function cudnntest.BatchNormalization() + local size = { + math.random(1, 32), + math.random(16, 256), + } + testBatchNormalization('BatchNormalization', size) +end + +function cudnntest.SpatialBatchNormalization() + local size = { + math.random(1, 32), + math.random(1, 32), + math.random(5, 10), + math.random(5, 10), + } + testBatchNormalization('SpatialBatchNormalization', size) +end + +function cudnntest.SpatialBatchNormalization() + local size = { + math.random(1, 32), + math.random(1, 32), + math.random(2, 6), + math.random(2, 6), + math.random(2, 6), + } + testBatchNormalization('VolumetricBatchNormalization', size) +end + function cudnntest.SpatialCrossEntropyCriterion() -- batch local numLabels = math.random(5,10) @@ -1188,11 +1212,11 @@ function cudnntest.SpatialCrossEntropyCriterion() ggi[{{}, {}, {i}, {j}}]:copy(ggi1) end end - + -- nn.CrossEntropy in contrast to cudnn.SpatialCrossEntropyCriterion cannot -- average over the last spatial dimensions because it is run in a loop ggi:div(h * w) - + local err = (gi - ggi):abs():max() mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward') -- cgit v1.2.3