Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-02-04 17:53:04 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-02-09 21:08:41 +0300
commitb4ebdf2f95ee9f1429825a0d7b0948721e407d82 (patch)
tree77b6183d0c258d07d5f5e6338c48ccb5fc5a877b /CosineDistance.lua
parent4a433463fcd17c4ec3242586da7e9fc859feb3ae (diff)
nn.clearState
Diffstat (limited to 'CosineDistance.lua')
-rw-r--r--CosineDistance.lua17
1 files changed, 17 insertions, 0 deletions
diff --git a/CosineDistance.lua b/CosineDistance.lua
index d135e05..2988c65 100644
--- a/CosineDistance.lua
+++ b/CosineDistance.lua
@@ -73,6 +73,11 @@ function CosineDistance:updateGradInput(input, gradOutput)
not_batch = true
end
+ if #self.gradInput ~= 2 then
+ self.gradInput[1] = self.gradInput[1] or v1.new()
+ self.gradInput[2] = self.gradInput[2] or v1.new()
+ end
+
local gw1 = self.gradInput[1]
local gw2 = self.gradInput[2]
gw1:resizeAs(v1):copy(v2)
@@ -97,3 +102,15 @@ function CosineDistance:updateGradInput(input, gradOutput)
return self.gradInput
end
+
+function CosineDistance:clearState()
+ nn.utils.clear(self, {
+ 'buffer',
+ 'w1',
+ 'w22',
+ 'w',
+ 'w32',
+ 'ones',
+ })
+ return parent.clearState(self)
+end