diff options
author | Páidí Creed <paidi@swiftkey.net> | 2014-01-28 14:42:17 +0400 |
---|---|---|
committer | Páidí Creed <paidi@swiftkey.net> | 2014-01-28 14:42:17 +0400 |
commit | 2152758d904b4cab0ace02817203a65d92acbb10 (patch) | |
tree | 6a58c70942ce189262e1b270132e77a87195a016 /test | |
parent | aa60b6e2be23beb899b3eca28c762793afea52a6 (diff) | |
parent | 60947473ba346a04c794dd63335633640351ae46 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 89db059..27bb114 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1576,6 +1576,60 @@ function nntest.Module_getParameters_7() mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector') end +function nntest.PairwiseDistance() + -- Note: testJacobian doesn't support table inputs, and rather than re-write + -- it so that it does, I'll just use a split table module on the input. + -- I assume both SplitTable and Sequential do not have bugs, otherwise this + -- test will break. + for p = 1,4 do -- test a few Lp norms + -- TEST CASE 1: non-batch input, same code path but includes a resize + local ini = math.random(10,20) + local input = torch.Tensor(2, ini):zero() + local module = nn.Sequential() + module:add(nn.SplitTable(1)) + module:add(nn.PairwiseDistance(p)) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, ' error on state ') + + local ferr,berr = jac.testIO(module,input) + mytester:asserteq(ferr, 0, torch.typename(module)..' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module)..' - i/o backward err ') + + -- Also check that the forward prop result is correct. + input = torch.rand(2, ini) + err = torch.dist(input:select(1,1), input:select(1,2), p) - + module:forward(input)[1] + mytester:assertlt(err,precision, ' error on non-batch fprop ') + + -- TEST CASE 2: batch input + local inj = math.random(10,20) + input = torch.Tensor(2, inj, ini):zero() + + -- (Rebuild the module to avoid correlated tests) + module = nn.Sequential() + module:add(nn.SplitTable(1)) + module:add(nn.PairwiseDistance(p)) + + err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, ' error on state ') + + -- Also check that the forward prop result is correct. + -- manually calculate each distance separately + local inputa = torch.rand(inj,ini) + local inputb = torch.rand(inj,ini) + local dist_manual = torch.Tensor(inj) + for i=1, inputa:size(1) do + dist_manual[i] = torch.dist(inputa:select(1,i), inputb:select(1,i),p) + end + -- compare the distances to the module's fprop + local dist = module:forward(torch.cat(inputa,inputb,1):resize(2,inj,ini)) + err = dist - dist_manual + mytester:assertlt(err:norm(), precision, torch.typename(module) .. + ' error on batch fprop ') + end +end + mytester:add(nntest) if not nn then |