Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCédric Deltheil <cedric@moodstocks.com>2015-02-23 17:17:24 +0300
committerCédric Deltheil <cedric@moodstocks.com>2015-02-23 17:21:29 +0300
commit0e4bb234d6faf934ba1a25d4760d6d9292518ca6 (patch)
treeb6dd1261810ca5551a30476cb8b20b7372019ac0 /CosineEmbeddingCriterion.lua
parent822a742d3e94cac56d3c35dea75cbfca691db361 (diff)
CosineEmbeddingCriterion: zero grads below margin
Diffstat (limited to 'CosineEmbeddingCriterion.lua')
-rw-r--r--CosineEmbeddingCriterion.lua16
1 files changed, 9 insertions, 7 deletions
diff --git a/CosineEmbeddingCriterion.lua b/CosineEmbeddingCriterion.lua
index 293ae23..2dc3e9c 100644
--- a/CosineEmbeddingCriterion.lua
+++ b/CosineEmbeddingCriterion.lua
@@ -32,16 +32,18 @@ function CosineEmbeddingCriterion:updateGradInput(input, y)
gw2:resizeAs(v1)
gw1:zero()
- gw1:add(1/(self.w2*self.w3), v2)
- gw1:add(-self.w1/(self.w22*self.w2*self.w3), v1)
-
gw2:zero()
- gw2:add(1/(self.w2*self.w3), v1)
- gw2:add(-self.w1/(self.w32*self.w2*self.w3), v2)
+ if self.output > 0 then
+ gw1:add(1/(self.w2*self.w3), v2)
+ gw1:add(-self.w1/(self.w22*self.w2*self.w3), v1)
+
+ gw2:add(1/(self.w2*self.w3), v1)
+ gw2:add(-self.w1/(self.w32*self.w2*self.w3), v2)
+ end
if y == 1 then
- gw1 = -gw1
- gw2 = -gw2
+ gw1:mul(-1)
+ gw2:mul(-1)
end
self.gradInput = {gw1, gw2}
return self.gradInput