Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorPáidí Creed <paidi@swiftkey.net>2014-01-28 14:42:17 +0400
committerPáidí Creed <paidi@swiftkey.net>2014-01-28 14:42:17 +0400
commit2152758d904b4cab0ace02817203a65d92acbb10 (patch)
tree6a58c70942ce189262e1b270132e77a87195a016 /test
parentaa60b6e2be23beb899b3eca28c762793afea52a6 (diff)
parent60947473ba346a04c794dd63335633640351ae46 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'test')
-rw-r--r--test/test.lua54
1 files changed, 54 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 89db059..27bb114 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1576,6 +1576,60 @@ function nntest.Module_getParameters_7()
mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector')
end
+function nntest.PairwiseDistance()
+ -- Note: testJacobian doesn't support table inputs, and rather than re-write
+ -- it so that it does, I'll just use a split table module on the input.
+ -- I assume both SplitTable and Sequential do not have bugs, otherwise this
+ -- test will break.
+ for p = 1,4 do -- test a few Lp norms
+ -- TEST CASE 1: non-batch input, same code path but includes a resize
+ local ini = math.random(10,20)
+ local input = torch.Tensor(2, ini):zero()
+ local module = nn.Sequential()
+ module:add(nn.SplitTable(1))
+ module:add(nn.PairwiseDistance(p))
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, ' error on state ')
+
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module)..' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module)..' - i/o backward err ')
+
+ -- Also check that the forward prop result is correct.
+ input = torch.rand(2, ini)
+ err = torch.dist(input:select(1,1), input:select(1,2), p) -
+ module:forward(input)[1]
+ mytester:assertlt(err,precision, ' error on non-batch fprop ')
+
+ -- TEST CASE 2: batch input
+ local inj = math.random(10,20)
+ input = torch.Tensor(2, inj, ini):zero()
+
+ -- (Rebuild the module to avoid correlated tests)
+ module = nn.Sequential()
+ module:add(nn.SplitTable(1))
+ module:add(nn.PairwiseDistance(p))
+
+ err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, ' error on state ')
+
+ -- Also check that the forward prop result is correct.
+ -- manually calculate each distance separately
+ local inputa = torch.rand(inj,ini)
+ local inputb = torch.rand(inj,ini)
+ local dist_manual = torch.Tensor(inj)
+ for i=1, inputa:size(1) do
+ dist_manual[i] = torch.dist(inputa:select(1,i), inputb:select(1,i),p)
+ end
+ -- compare the distances to the module's fprop
+ local dist = module:forward(torch.cat(inputa,inputb,1):resize(2,inj,ini))
+ err = dist - dist_manual
+ mytester:assertlt(err:norm(), precision, torch.typename(module) ..
+ ' error on batch fprop ')
+ end
+end
+
mytester:add(nntest)
if not nn then