diff options
author | Jonathan Tompson <tompson@cims.nyu.edu> | 2013-10-20 19:51:07 +0400 |
---|---|---|
committer | Jonathan Tompson <tompson@cims.nyu.edu> | 2013-10-20 19:51:07 +0400 |
commit | 7286784cbd0af841501bfadb98a4a2cee28199b6 (patch) | |
tree | 898cebbef05e7e99dfa05af345f180c971ee702b | |
parent | e1fbc0cccab633fd0615dc64ba9fd52f64072622 (diff) |
Fixed a bug in PairwiseDistance when the output Lp norm is zero (which results in a divide-by-zero issue). Rewrote PairwiseDistance following Clement's suggestion to have only one code path. Fixed a small bug in extra/test/test.lua where the input to the non-batch fprop test was zero.
-rw-r--r-- | PairwiseDistance.lua | 95 | ||||
-rw-r--r-- | test/test.lua | 7 |
2 files changed, 54 insertions, 48 deletions
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua index 4941210..f9f77f2 100644 --- a/PairwiseDistance.lua +++ b/PairwiseDistance.lua @@ -6,28 +6,30 @@ function PairwiseDistance:__init(p) -- state self.gradInput = {torch.Tensor(), torch.Tensor()} self.output = torch.Tensor(1) + self.diff = torch.Tensor() self.norm=p end function PairwiseDistance:updateOutput(input) - if input[1]:dim() == 1 then - self.output[1]=input[1]:dist(input[2],self.norm) - elseif input[1]:dim() == 2 then - self.diff = self.diff or input[1].new() - self.diff:resizeAs(input[1]) - - local diff = self.diff:zero() - --local diff = torch.add(input[1], -1, input[2]) - diff:add(input[1], -1, input[2]) - diff:abs() - - self.output:resize(input[1]:size(1)) - self.output:zero() - self.output:add(diff:pow(self.norm):sum(2)) - self.output:pow(1./self.norm) - else + if input[1]:dim() > 2 then error('input must be vector or matrix') end + if input[1]:dim() == 1 then + -- Reshape the input so it always looks like a batch (avoids multiple + -- code-paths). (Clement's good suggestion) + input[1]:resize(1,input[1]:size(1)) + input[2]:resize(1,input[2]:size(1)) + end + + self.diff:resizeAs(input[1]) + self.diff:zero() + self.diff:add(input[1], -1, input[2]) + self.diff:abs() + + self.output:resize(input[1]:size(1)) + self.output:zero() + self.output:add(self.diff:pow(self.norm):sum(2)) + self.output:pow(1./self.norm) return self.output end @@ -38,46 +40,49 @@ local function mathsign(x) end function PairwiseDistance:updateGradInput(input, gradOutput) + if input[1]:dim() > 2 then + error('input must be vector or matrix') + end + + if input[1]:dim() == 1 then + -- Reshape the input so it always looks like a batch (avoids multiple + -- code-paths). 
(Clement's good suggestion) + input[1]:resize(1,input[1]:size(1)) + input[2]:resize(1,input[2]:size(1)) + end + self.gradInput[1]:resize(input[1]:size()) self.gradInput[2]:resize(input[2]:size()) self.gradInput[1]:copy(input[1]) - self.gradInput[1]:add(-1, input[2]) + self.gradInput[1]:add(-1, input[2]) + if self.norm==1 then self.gradInput[1]:apply(mathsign) else - -- See here for derivative of p-norm: + -- Note: derivative of p-norm: -- d/dx_k(||x||_p) = (x_k * abs(x_k)^(p-2)) / (||x||_p)^(p-1) - -- http://en.wikipedia.org/wiki/Norm_(mathematics) - self.gradInput[1]:cmul(self.gradInput[1]:clone():abs():pow(self.norm-2)) - if input[1]:dim() == 1 then - -- Avoid the expand for dimension 1 - self.gradInput[1]:mul(math.pow(self.output[1],-(self.norm-1))) - elseif input[1]:dim() == 2 then - -- This is a little messy... But it does work - self.outExpand = self.outExpand or self.output.new() - self.outExpand:resize(self.output:size(1), 1) - self.outExpand:copy(self.output) - self.outExpand:pow(-(self.norm-1)) - self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1), - self.gradInput[1]:size(2))) - else - error('input must be vector or matrix') + if (self.norm > 2) then + self.gradInput[1]:cmul(self.gradInput[1]:clone():abs():pow(self.norm-2)) end + + self.outExpand = self.outExpand or self.output.new() + self.outExpand:resize(self.output:size(1), 1) + self.outExpand:copy(self.output) + self.outExpand:add(1.0e-6) -- Prevent divide by zero errors + self.outExpand:pow(-(self.norm-1)) + self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1), + self.gradInput[1]:size(2))) end - if input[1]:dim() == 1 then - self.gradInput[1]:mul(gradOutput[1]) - elseif input[1]:dim() == 2 then - self.grad = self.grad or gradOutput.new() - self.ones = self.ones or gradOutput.new() + + self.grad = self.grad or gradOutput.new() + self.ones = self.ones or gradOutput.new() - self.grad:resizeAs(input[1]):zero() - self.ones:resize(input[1]:size(2)):fill(1) + 
self.grad:resizeAs(input[1]):zero() + self.ones:resize(input[1]:size(2)):fill(1) - self.grad:addr(gradOutput, self.ones) - self.gradInput[1]:cmul(self.grad) - else - error('input must be vector or matrix') - end + self.grad:addr(gradOutput, self.ones) + self.gradInput[1]:cmul(self.grad) + self.gradInput[2]:zero():add(-1, self.gradInput[1]) return self.gradInput end diff --git a/test/test.lua b/test/test.lua index 5a1d469..0d54e3d 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1536,7 +1536,7 @@ function nntest.PairwiseDistance() -- I assume both SplitTable and Sequential do not have bugs, otherwise this -- test will break. for p = 1,4 do -- test a few Lp norms - -- TEST CASE 1: non-batch inputs + -- TEST CASE 1: non-batch input, same code path but includes a resize local ini = math.random(10,20) local input = torch.Tensor(2, ini):zero() local module = nn.Sequential() @@ -1550,7 +1550,8 @@ function nntest.PairwiseDistance() mytester:asserteq(ferr, 0, torch.typename(module)..' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module)..' - i/o backward err ') - -- Also check that the forward prop result is correct + -- Also check that the forward prop result is correct. + input = torch.rand(2, ini) err = torch.dist(input:select(1,1), input:select(1,2), p) - module:forward(input)[1] mytester:assertlt(err,precision, ' error on non-batch fprop ') @@ -1567,7 +1568,7 @@ function nntest.PairwiseDistance() err = jac.testJacobian(module,input) mytester:assertlt(err,precision, ' error on state ') - -- Also check that the forward prop result is correct + -- Also check that the forward prop result is correct. -- manually calculate each distance separately local inputa = torch.rand(inj,ini) local inputb = torch.rand(inj,ini) |