Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonathan Tompson <tompson@cims.nyu.edu>2013-11-18 20:08:20 +0400
committerJonathan Tompson <tompson@cims.nyu.edu>2013-11-18 20:09:14 +0400
commitf87f7415eb8fc3f1b96f28a40a7537dd57e0876f (patch)
tree8caa1f96168a2d062ecf82ca5b9499632958dbda /PairwiseDistance.lua
parent3ce35f0bac29d4a9645ba71ee9a6cde6d93091d2 (diff)
Fixed a bug in PairwiseDistance where the gradInput table isn't converted when Module.type function is called (this bug has always existed and is not due to the recent changes).
Diffstat (limited to 'PairwiseDistance.lua')
-rw-r--r--PairwiseDistance.lua15
1 files changed, 15 insertions, 0 deletions
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua
index 79569c9..b8c2f1e 100644
--- a/PairwiseDistance.lua
+++ b/PairwiseDistance.lua
@@ -86,3 +86,18 @@ function PairwiseDistance:updateGradInput(input, gradOutput)
self.gradInput[2]:zero():add(-1, self.gradInput[1])
return self.gradInput
end
+
+-- save away Module:type(type) for later use.
+PairwiseDistance._parent_type = parent.type
+
+-- Fix the bug where tmp = nn.PairwiseDistance:cuda() fails to convert table
+-- contents. We could, and probably should, change Module.lua to loop over
+-- and convert all the table elements in a module, but that might have
+-- repercussions, so this is a safer solution.
+function PairwiseDistance:type(type)
+ self:_parent_type(type) -- Call the parent (Module) type function
+ -- Now convert the left over table elements
+ self.gradInput[1] = self.gradInput[1]:type(type)
+ self.gradInput[2] = self.gradInput[2]:type(type)
+end
+