Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-07-21 22:13:00 +0300
committerSoumith Chintala <soumith@gmail.com>2015-07-21 22:13:00 +0300
commit55760ed22852e7ab00c2d2424bd427882890b41a (patch)
treef5560dff6c8bf73a0dc7108b1d9dd06093f5857b /test.lua
parentb29c6fd53bbad1935a321a3ffb5a8eb26832cdbf (diff)
parentfcb300cf07bf03b924fd9766adc5cff0f838593d (diff)
Merge pull request #300 from fmassa/cosinedist_batch
CosineDistance supports batch mode
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua50
1 files changed, 50 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 3f37dac..cd48433 100644
--- a/test.lua
+++ b/test.lua
@@ -3724,6 +3724,56 @@ function nntest.BatchMMTransposeBoth()
end
end
+function nntest.CosineDistance()
+ local indim = math.random(1,10)
+ local input = {torch.rand(indim),torch.rand(indim)}
+
+ -- check forward against previous implementation
+ local module = nn.CosineDistance()
+
+ local w1 = input[1]:dot(input[2])
+ local w2 = math.sqrt(input[1]:dot(input[1]))
+ local w3 = math.sqrt(input[2]:dot(input[2]))
+ local output_old = w1/w2/w3
+
+ local output = module:forward(input)
+
+ mytester:assertlt(math.abs(output_old-output[1]),precision,'error on forward ')
+
+
+ -- check gradients
+ -- Note: testJacobian doesn't support table inputs, and rather than re-write
+ -- it so that it does, I'll just use a split table module on the input.
+ -- I assume both SplitTable and Sequential do not have bugs, otherwise this
+ -- test will break.
+ local input = torch.rand(2,indim)
+ local module = nn.Sequential()
+ module:add(nn.SplitTable(1))
+ module:add(nn.CosineDistance())
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ -- IO
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ -- batch
+ -- rebuild module to avoid correlated tests
+ local module = nn.Sequential()
+ module:add(nn.SplitTable(1))
+ module:add(nn.CosineDistance())
+
+ local nframes = math.random(1,10)
+ local indim = math.random(1,10)
+ local input = torch.rand(2,nframes,indim)
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'batch error on state ')
+
+end
+
function nntest.CosineEmbeddingCriterion()
local v1 = torch.Tensor{1, 0}
local v2 = torch.Tensor{0.5, math.sqrt(3)*0.5}