Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCédric Deltheil <cedric@moodstocks.com>2015-02-23 17:17:24 +0300
committerCédric Deltheil <cedric@moodstocks.com>2015-02-23 17:21:29 +0300
commit0e4bb234d6faf934ba1a25d4760d6d9292518ca6 (patch)
treeb6dd1261810ca5551a30476cb8b20b7372019ac0 /test.lua
parent822a742d3e94cac56d3c35dea75cbfca691db361 (diff)
CosineEmbeddingCriterion: zero grads below margin
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua13
1 files changed, 13 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 9604aac..42a0465 100644
--- a/test.lua
+++ b/test.lua
@@ -3124,6 +3124,19 @@ function nntest.BatchMMTransposeBoth()
end
end
+function nntest.CosineEmbeddingCriterion()
+ local v1 = torch.Tensor{1, 0}
+ local v2 = torch.Tensor{0.5, math.sqrt(3)*0.5}
+
+ local crit = nn.CosineEmbeddingCriterion(0.6)
+ local output = crit:forward({v1, v2}, -1) -- must be called before backward
+ local grads = crit:backward({v1, v2}, -1)
+
+ local zero = torch.Tensor(2):zero()
+ equal(grads[1], zero, 'gradient should be zero')
+ equal(grads[2], zero, 'gradient should be zero')
+end
+
mytester:add(nntest)
if not nn then