From 83a3815dc70255c978405e8e966d7b02d580cc11 Mon Sep 17 00:00:00 2001 From: soumith Date: Thu, 14 May 2015 14:29:42 -0700 Subject: batchnorm is clonable by adding the running estimates to constructor fixing batchnorm tests --- test.lua | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'test.lua') diff --git a/test.lua b/test.lua index d82e470..1b4847c 100644 --- a/test.lua +++ b/test.lua @@ -438,6 +438,13 @@ function nntest.Sqrt() local err = out:dist(in1:sqrt()) mytester:assertlt(err, 1e-15, torch.typename(module) .. ' - forward err ') + -- Test zero inputs; we will avoid a div-by-zero by setting to zero + local zin = torch.DoubleTensor(5, 7):zero() + module:forward(zin) + local zgradout = torch.rand(5, 7) + local zgradin = module:backward(zin, zgradout) + mytester:assertTensorEq(zgradin, torch.DoubleTensor(5, 7):zero(), 0.000001, "error in sqrt backward singularity") + local ini = math.random(3,5) local inj = math.random(3,5) local ink = math.random(3,5) @@ -3471,7 +3478,7 @@ function nntest.BatchNormalization() mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') -- batch norm without affine transform - module = nn.BatchNormalization(0) + module = nn.BatchNormalization(indim, 1e-5, 0.1, false) local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, 'error on state ') @@ -3525,7 +3532,7 @@ function nntest.SpatialBatchNormalization() mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') -- batch norm without affine transform - module = nn.SpatialBatchNormalization(0) + module = nn.SpatialBatchNormalization(indim, 1e-5, 0.1, false) local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, 'error on state ') -- cgit v1.2.3