diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-04 17:53:04 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-02-09 21:08:41 +0300 |
commit | b4ebdf2f95ee9f1429825a0d7b0948721e407d82 (patch) | |
tree | 77b6183d0c258d07d5f5e6338c48ccb5fc5a877b /CosineDistance.lua | |
parent | 4a433463fcd17c4ec3242586da7e9fc859feb3ae (diff) |
nn.clearState
Diffstat (limited to 'CosineDistance.lua')
-rw-r--r-- | CosineDistance.lua | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/CosineDistance.lua b/CosineDistance.lua index d135e05..2988c65 100644 --- a/CosineDistance.lua +++ b/CosineDistance.lua @@ -73,6 +73,11 @@ function CosineDistance:updateGradInput(input, gradOutput) not_batch = true end + if #self.gradInput ~= 2 then + self.gradInput[1] = self.gradInput[1] or v1.new() + self.gradInput[2] = self.gradInput[2] or v1.new() + end + local gw1 = self.gradInput[1] local gw2 = self.gradInput[2] gw1:resizeAs(v1):copy(v2) @@ -97,3 +102,15 @@ function CosineDistance:updateGradInput(input, gradOutput) return self.gradInput end + +function CosineDistance:clearState() + nn.utils.clear(self, { + 'buffer', + 'w1', + 'w22', + 'w', + 'w32', + 'ones', + }) + return parent.clearState(self) +end |