diff options
author | Cédric Deltheil <cedric@moodstocks.com> | 2015-02-23 17:17:24 +0300 |
---|---|---|
committer | Cédric Deltheil <cedric@moodstocks.com> | 2015-02-23 17:21:29 +0300 |
commit | 0e4bb234d6faf934ba1a25d4760d6d9292518ca6 (patch) | |
tree | b6dd1261810ca5551a30476cb8b20b7372019ac0 /CosineEmbeddingCriterion.lua | |
parent | 822a742d3e94cac56d3c35dea75cbfca691db361 (diff) |
CosineEmbeddingCriterion: zero grads below margin
Diffstat (limited to 'CosineEmbeddingCriterion.lua')
-rw-r--r-- | CosineEmbeddingCriterion.lua | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/CosineEmbeddingCriterion.lua b/CosineEmbeddingCriterion.lua index 293ae23..2dc3e9c 100644 --- a/CosineEmbeddingCriterion.lua +++ b/CosineEmbeddingCriterion.lua @@ -32,16 +32,18 @@ function CosineEmbeddingCriterion:updateGradInput(input, y) gw2:resizeAs(v1) gw1:zero() - gw1:add(1/(self.w2*self.w3), v2) - gw1:add(-self.w1/(self.w22*self.w2*self.w3), v1) - gw2:zero() - gw2:add(1/(self.w2*self.w3), v1) - gw2:add(-self.w1/(self.w32*self.w2*self.w3), v2) + if self.output > 0 then + gw1:add(1/(self.w2*self.w3), v2) + gw1:add(-self.w1/(self.w22*self.w2*self.w3), v1) + + gw2:add(1/(self.w2*self.w3), v1) + gw2:add(-self.w1/(self.w32*self.w2*self.w3), v2) + end if y == 1 then - gw1 = -gw1 - gw2 = -gw2 + gw1:mul(-1) + gw2:mul(-1) end self.gradInput = {gw1, gw2} return self.gradInput |