diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-12-09 21:25:50 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-12-23 20:38:58 +0300 |
commit | 4178c4ef3e17425940022047d7dd12a645d3ed11 (patch) | |
tree | ff1f29874f85695948b1a4efc5a98ea0c1c33900 /test.lua | |
parent | d5ab2ca3c2b4d4cba7bdfafd8d86daa63bea71f7 (diff) |
WeightedEuclidean batch mode
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 37 |
1 files changed, 36 insertions, 1 deletions
@@ -456,7 +456,7 @@ function nntest.WeightedEuclidean() local inj = math.random(13,5) local input = torch.Tensor(ini):zero() local module = nn.WeightedEuclidean(ini,inj) - + local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, 'error on state ') @@ -469,6 +469,41 @@ function nntest.WeightedEuclidean() local ferr,berr = jac.testIO(module,input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + -- test batch + local bs = math.random(3,5) + input:uniform(0,1) + local output = module:forward(input):clone() + module:zeroGradParameters() + local gradInput = module:backward(input, output):clone() + local params, gradParams = module:parameters() + for i=1,#params do + params[i] = params[i]:clone() + end + local input2 = input:view(1, -1):repeatTensor(bs, 1) + local output2 = module:forward(input2) + module:zeroGradParameters() + local gradInput2 = module:backward(input2, output2, 1/bs) + local params2, gradParams2 = module:parameters() + mytester:assertTensorEq(output2[bs-1], output, 0.000001, "error in batch updateOutput") + mytester:assertTensorEq(gradInput2[bs-1], gradInput, 0.000001, "error in batch updateGradInput") + mytester:assertTensorEq(gradParams[1], gradParams2[1], 0.000001, "error in batch accGradParameters (gradTemplates)") + mytester:assertTensorEq(gradParams[2], gradParams2[2], 0.000001, "error in batch accGradParameters (gradDiagCov)") + + input:zero() + module:zeroGradParameters() + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'error on state ') + + local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight) + mytester:assertlt(err,precision, 'error on weight ') + + local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) + mytester:assertlt(err,precision, 'error on bias ') + + local ferr,berr = jac.testIO(module,input2) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end local function criterionJacobianTest1D(cri, input, target) |