Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'PairwiseDistance.lua')
-rw-r--r--PairwiseDistance.lua53
1 files changed, 47 insertions, 6 deletions
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua
index d9e6f81..affc2e5 100644
--- a/PairwiseDistance.lua
+++ b/PairwiseDistance.lua
@@ -6,6 +6,7 @@ function PairwiseDistance:__init(p)
-- state
self.gradInput = {torch.Tensor(), torch.Tensor()}
self.output = torch.Tensor(1)
+ self.diff = torch.Tensor()
self.norm=p
end
@@ -17,8 +18,8 @@ function PairwiseDistance:updateOutput(input)
self.diff:resizeAs(input[1])
local diff = self.diff:zero()
- --local diff = torch.add(input[1], -1, input[2])
diff:add(input[1], -1, input[2])
+ diff:abs()
self.output:resize(input[1]:size(1))
self.output:zero()
@@ -27,7 +28,10 @@ function PairwiseDistance:updateOutput(input)
else
error('input must be vector or matrix')
end
-
+ if input[1]:dim() > 2 then
+ error('input must be vector or matrix')
+ end
+
return self.output
end
@@ -37,16 +41,39 @@ local function mathsign(x)
end
function PairwiseDistance:updateGradInput(input, gradOutput)
+ if input[1]:dim() > 2 then
+ error('input must be vector or matrix')
+ end
+
self.gradInput[1]:resize(input[1]:size())
self.gradInput[2]:resize(input[2]:size())
self.gradInput[1]:copy(input[1])
- self.gradInput[1]:add(-1, input[2])
+ self.gradInput[1]:add(-1, input[2])
+
if self.norm==1 then
self.gradInput[1]:apply(mathsign)
+ else
+ -- Note: derivative of p-norm:
+ -- d/dx_k(||x||_p) = (x_k * abs(x_k)^(p-2)) / (||x||_p)^(p-1)
+ if (self.norm > 2) then
+ self.gradInput[1]:cmul(self.gradInput[1]:clone():abs():pow(self.norm-2))
+ end
+
+ if (input[1]:dim() > 1) then
+ self.outExpand = self.outExpand or self.output.new()
+ self.outExpand:resize(self.output:size(1), 1)
+ self.outExpand:copy(self.output)
+ self.outExpand:add(1.0e-6) -- Prevent divide by zero errors
+ self.outExpand:pow(-(self.norm-1))
+ self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1),
+ self.gradInput[1]:size(2)))
+ else
+ self.gradInput[1]:mul(math.pow(self.output[1] + 1e-6, -(self.norm-1)))
+ end
end
if input[1]:dim() == 1 then
self.gradInput[1]:mul(gradOutput[1])
- elseif input[1]:dim() == 2 then
+ else
self.grad = self.grad or gradOutput.new()
self.ones = self.ones or gradOutput.new()
@@ -55,9 +82,23 @@ function PairwiseDistance:updateGradInput(input, gradOutput)
self.grad:addr(gradOutput, self.ones)
self.gradInput[1]:cmul(self.grad)
- else
- error('input must be vector or matrix')
end
self.gradInput[2]:zero():add(-1, self.gradInput[1])
return self.gradInput
end
+
+-- save away Module:type(type) for later use.
+PairwiseDistance._parent_type = parent.type
+
+-- Fix the bug where tmp = nn.PairwiseDistance:cuda() fails to convert table
+-- contents. We could, and probably should, change Module.lua to loop over
+-- and convert all the table elements in a module, but that might have
+-- repercussions, so this is a safer solution.
+function PairwiseDistance:type(type)
+ self:_parent_type(type) -- Call the parent (Module) type function
+ -- Now convert the left over table elements
+ self.gradInput[1] = self.gradInput[1]:type(type)
+ self.gradInput[2] = self.gradInput[2]:type(type)
+ return self
+end
+