From 94f4b620b5e31de5abf9450189b9464e84849d5d Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sun, 21 Jun 2015 13:17:10 +0200 Subject: CosineDistance supports batch mode keeps backward compatibility with saved models --- test.lua | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) (limited to 'test.lua') diff --git a/test.lua b/test.lua index 55818e1..773fc43 100644 --- a/test.lua +++ b/test.lua @@ -3577,6 +3577,56 @@ function nntest.BatchMMTransposeBoth() end end +function nntest.CosineDistance() + local indim = math.random(1,10) + local input = {torch.rand(indim),torch.rand(indim)} + + -- check forward against previous implementation + local module = nn.CosineDistance() + + local w1 = input[1]:dot(input[2]) + local w2 = math.sqrt(input[1]:dot(input[1])) + local w3 = math.sqrt(input[2]:dot(input[2])) + local output_old = w1/w2/w3 + + local output = module:forward(input) + + mytester:assertlt(math.abs(output_old-output[1]),precision,'error on forward ') + + + -- check gradients + -- Note: testJacobian doesn't support table inputs, and rather than re-write + -- it so that it does, I'll just use a split table module on the input. + -- I assume both SplitTable and Sequential do not have bugs, otherwise this + -- test will break. + local input = torch.rand(2,indim) + local module = nn.Sequential() + module:add(nn.SplitTable(1)) + module:add(nn.CosineDistance()) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'error on state ') + + -- IO + local ferr,berr = jac.testIO(module,input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + -- batch + -- rebuild module to avoid correlated tests + local module = nn.Sequential() + module:add(nn.SplitTable(1)) + module:add(nn.CosineDistance()) + + local nframes = math.random(1,10) + local indim = math.random(1,10) + local input = torch.rand(2,nframes,indim) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'batch error on state ') + +end + function nntest.CosineEmbeddingCriterion() local v1 = torch.Tensor{1, 0} local v2 = torch.Tensor{0.5, math.sqrt(3)*0.5} -- cgit v1.2.3