diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-07-21 22:13:00 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-07-21 22:13:00 +0300 |
commit | 55760ed22852e7ab00c2d2424bd427882890b41a (patch) | |
tree | f5560dff6c8bf73a0dc7108b1d9dd06093f5857b /test.lua | |
parent | b29c6fd53bbad1935a321a3ffb5a8eb26832cdbf (diff) | |
parent | fcb300cf07bf03b924fd9766adc5cff0f838593d (diff) |
Merge pull request #300 from fmassa/cosinedist_batch
CosineDistance supports batch mode
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 50 |
1 files changed, 50 insertions, 0 deletions
@@ -3724,6 +3724,56 @@ function nntest.BatchMMTransposeBoth() end end +function nntest.CosineDistance() + local indim = math.random(1,10) + local input = {torch.rand(indim),torch.rand(indim)} + + -- check forward against previous implementation + local module = nn.CosineDistance() + + local w1 = input[1]:dot(input[2]) + local w2 = math.sqrt(input[1]:dot(input[1])) + local w3 = math.sqrt(input[2]:dot(input[2])) + local output_old = w1/w2/w3 + + local output = module:forward(input) + + mytester:assertlt(math.abs(output_old-output[1]),precision,'error on forward ') + + + -- check gradients + -- Note: testJacobian doesn't support table inputs, and rather than re-write + -- it so that it does, I'll just use a split table module on the input. + -- I assume both SplitTable and Sequential do not have bugs, otherwise this + -- test will break. + local input = torch.rand(2,indim) + local module = nn.Sequential() + module:add(nn.SplitTable(1)) + module:add(nn.CosineDistance()) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'error on state ') + + -- IO + local ferr,berr = jac.testIO(module,input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + -- batch + -- rebuild module to avoid correlated tests + local module = nn.Sequential() + module:add(nn.SplitTable(1)) + module:add(nn.CosineDistance()) + + local nframes = math.random(1,10) + local indim = math.random(1,10) + local input = torch.rand(2,nframes,indim) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'batch error on state ') + +end + function nntest.CosineEmbeddingCriterion() local v1 = torch.Tensor{1, 0} local v2 = torch.Tensor{0.5, math.sqrt(3)*0.5} |