Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonathan Tompson <tompson@cims.nyu.edu>2013-10-20 19:51:07 +0400
committerJonathan Tompson <tompson@cims.nyu.edu>2013-10-20 19:51:07 +0400
commit7286784cbd0af841501bfadb98a4a2cee28199b6 (patch)
tree898cebbef05e7e99dfa05af345f180c971ee702b /PairwiseDistance.lua
parente1fbc0cccab633fd0615dc64ba9fd52f64072622 (diff)
fixed a bug in Pairwise distance when the output Lp norm is zero (which results in a divide by zero issue). Rewrote PairwiseDistance following Clement's suggestion to only have one codepath. Fixed a small bug in extra/test/test.lua where the input to the non-batch fprop test was zero.
Diffstat (limited to 'PairwiseDistance.lua')
-rw-r--r--PairwiseDistance.lua95
1 files changed, 50 insertions, 45 deletions
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua
index 4941210..f9f77f2 100644
--- a/PairwiseDistance.lua
+++ b/PairwiseDistance.lua
@@ -6,28 +6,30 @@ function PairwiseDistance:__init(p)
-- state
self.gradInput = {torch.Tensor(), torch.Tensor()}
self.output = torch.Tensor(1)
+ self.diff = torch.Tensor()
self.norm=p
end
function PairwiseDistance:updateOutput(input)
- if input[1]:dim() == 1 then
- self.output[1]=input[1]:dist(input[2],self.norm)
- elseif input[1]:dim() == 2 then
- self.diff = self.diff or input[1].new()
- self.diff:resizeAs(input[1])
-
- local diff = self.diff:zero()
- --local diff = torch.add(input[1], -1, input[2])
- diff:add(input[1], -1, input[2])
- diff:abs()
-
- self.output:resize(input[1]:size(1))
- self.output:zero()
- self.output:add(diff:pow(self.norm):sum(2))
- self.output:pow(1./self.norm)
- else
+ if input[1]:dim() > 2 then
error('input must be vector or matrix')
end
+ if input[1]:dim() == 1 then
+ -- Reshape the input so it always looks like a batch (avoids multiple
+ -- code-paths). (Clement's good suggestion)
+ input[1]:resize(1,input[1]:size(1))
+ input[2]:resize(1,input[2]:size(1))
+ end
+
+ self.diff:resizeAs(input[1])
+ self.diff:zero()
+ self.diff:add(input[1], -1, input[2])
+ self.diff:abs()
+
+ self.output:resize(input[1]:size(1))
+ self.output:zero()
+ self.output:add(self.diff:pow(self.norm):sum(2))
+ self.output:pow(1./self.norm)
return self.output
end
@@ -38,46 +40,49 @@ local function mathsign(x)
end
function PairwiseDistance:updateGradInput(input, gradOutput)
+ if input[1]:dim() > 2 then
+ error('input must be vector or matrix')
+ end
+
+ if input[1]:dim() == 1 then
+ -- Reshape the input so it always looks like a batch (avoids multiple
+ -- code-paths). (Clement's good suggestion)
+ input[1]:resize(1,input[1]:size(1))
+ input[2]:resize(1,input[2]:size(1))
+ end
+
self.gradInput[1]:resize(input[1]:size())
self.gradInput[2]:resize(input[2]:size())
self.gradInput[1]:copy(input[1])
- self.gradInput[1]:add(-1, input[2])
+ self.gradInput[1]:add(-1, input[2])
+
if self.norm==1 then
self.gradInput[1]:apply(mathsign)
else
- -- See here for derivative of p-norm:
+ -- Note: derivative of p-norm:
-- d/dx_k(||x||_p) = (x_k * abs(x_k)^(p-2)) / (||x||_p)^(p-1)
- -- http://en.wikipedia.org/wiki/Norm_(mathematics)
- self.gradInput[1]:cmul(self.gradInput[1]:clone():abs():pow(self.norm-2))
- if input[1]:dim() == 1 then
- -- Avoid the expand for dimension 1
- self.gradInput[1]:mul(math.pow(self.output[1],-(self.norm-1)))
- elseif input[1]:dim() == 2 then
- -- This is a little messy... But it does work
- self.outExpand = self.outExpand or self.output.new()
- self.outExpand:resize(self.output:size(1), 1)
- self.outExpand:copy(self.output)
- self.outExpand:pow(-(self.norm-1))
- self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1),
- self.gradInput[1]:size(2)))
- else
- error('input must be vector or matrix')
+ if (self.norm > 2) then
+ self.gradInput[1]:cmul(self.gradInput[1]:clone():abs():pow(self.norm-2))
end
+
+ self.outExpand = self.outExpand or self.output.new()
+ self.outExpand:resize(self.output:size(1), 1)
+ self.outExpand:copy(self.output)
+ self.outExpand:add(1.0e-6) -- Prevent divide by zero errors
+ self.outExpand:pow(-(self.norm-1))
+ self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1),
+ self.gradInput[1]:size(2)))
end
- if input[1]:dim() == 1 then
- self.gradInput[1]:mul(gradOutput[1])
- elseif input[1]:dim() == 2 then
- self.grad = self.grad or gradOutput.new()
- self.ones = self.ones or gradOutput.new()
+
+ self.grad = self.grad or gradOutput.new()
+ self.ones = self.ones or gradOutput.new()
- self.grad:resizeAs(input[1]):zero()
- self.ones:resize(input[1]:size(2)):fill(1)
+ self.grad:resizeAs(input[1]):zero()
+ self.ones:resize(input[1]:size(2)):fill(1)
- self.grad:addr(gradOutput, self.ones)
- self.gradInput[1]:cmul(self.grad)
- else
- error('input must be vector or matrix')
- end
+ self.grad:addr(gradOutput, self.ones)
+ self.gradInput[1]:cmul(self.grad)
+
self.gradInput[2]:zero():add(-1, self.gradInput[1])
return self.gradInput
end