diff options
author | Jonathan Tompson <tompson@cims.nyu.edu> | 2013-11-18 20:08:20 +0400 |
---|---|---|
committer | Jonathan Tompson <tompson@cims.nyu.edu> | 2013-11-18 20:09:14 +0400 |
commit | f87f7415eb8fc3f1b96f28a40a7537dd57e0876f (patch) | |
tree | 8caa1f96168a2d062ecf82ca5b9499632958dbda /PairwiseDistance.lua | |
parent | 3ce35f0bac29d4a9645ba71ee9a6cde6d93091d2 (diff) |
Fixed a bug in PairwiseDistance where the gradInput table isn't converted when Module.type function is called (this bug has always existed and is not due to the recent changes).
Diffstat (limited to 'PairwiseDistance.lua')
-rw-r--r-- | PairwiseDistance.lua | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua index 79569c9..b8c2f1e 100644 --- a/PairwiseDistance.lua +++ b/PairwiseDistance.lua @@ -86,3 +86,18 @@ function PairwiseDistance:updateGradInput(input, gradOutput) self.gradInput[2]:zero():add(-1, self.gradInput[1]) return self.gradInput end + +-- save away Module:type(type) for later use. +PairwiseDistance._parent_type = parent.type + +-- Fix the bug where tmp = nn.PairwiseDistance:cuda() fails to convert table +-- contents. We could, and probably should, change Module.lua to loop over +-- and convert all the table elements in a module, but that might have +-- repercussions, so this is a safer solution. +function PairwiseDistance:type(type) + self:_parent_type(type) -- Call the parent (Module) type function + -- Now convert the left over table elements + self.gradInput[1] = self.gradInput[1]:type(type) + self.gradInput[2] = self.gradInput[2]:type(type) +end + |