diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-12-23 21:52:31 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-12-23 21:52:31 +0300 |
commit | 289c6a2d91dfec05cc5c55105353e408b8541334 (patch) | |
tree | 14138f62030a9e312f6d2cc3502a1f263814440c /test.lua | |
parent | eecf41831e177b08dda3bcc51c6aa70c9df2f2d1 (diff) | |
parent | 4178c4ef3e17425940022047d7dd12a645d3ed11 (diff) |
Merge pull request #119 from nicholas-leonard/weightedeuclidean
WeightedEuclidean batch mode
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 37 |
1 files changed, 36 insertions, 1 deletions
@@ -477,7 +477,7 @@ function nntest.WeightedEuclidean() local inj = math.random(13,5) local input = torch.Tensor(ini):zero() local module = nn.WeightedEuclidean(ini,inj) - + local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, 'error on state ') @@ -490,6 +490,41 @@ function nntest.WeightedEuclidean() local ferr,berr = jac.testIO(module,input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + -- test batch + local bs = math.random(3,5) + input:uniform(0,1) + local output = module:forward(input):clone() + module:zeroGradParameters() + local gradInput = module:backward(input, output):clone() + local params, gradParams = module:parameters() + for i=1,#params do + params[i] = params[i]:clone() + end + local input2 = input:view(1, -1):repeatTensor(bs, 1) + local output2 = module:forward(input2) + module:zeroGradParameters() + local gradInput2 = module:backward(input2, output2, 1/bs) + local params2, gradParams2 = module:parameters() + mytester:assertTensorEq(output2[bs-1], output, 0.000001, "error in batch updateOutput") + mytester:assertTensorEq(gradInput2[bs-1], gradInput, 0.000001, "error in batch updateGradInput") + mytester:assertTensorEq(gradParams[1], gradParams2[1], 0.000001, "error in batch accGradParameters (gradTemplates)") + mytester:assertTensorEq(gradParams[2], gradParams2[2], 0.000001, "error in batch accGradParameters (gradDiagCov)") + + input:zero() + module:zeroGradParameters() + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'error on state ') + + local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight) + mytester:assertlt(err,precision, 'error on weight ') + + local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) + mytester:assertlt(err,precision, 'error on bias ') + + local ferr,berr = jac.testIO(module,input2) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end local function criterionJacobianTest1D(cri, input, target) |