diff options
author | Jonathan Tompson <tompson@cims.nyu.edu> | 2013-10-19 03:10:06 +0400 |
---|---|---|
committer | Jonathan Tompson <tompson@cims.nyu.edu> | 2013-10-19 03:10:06 +0400 |
commit | 4be845ee0d7550daa15a13d8b7c0c14065aa8242 (patch) | |
tree | 63730a14bf66ad5185252bba6ad4d156cc57e0c6 | |
parent | d4792ac4eb4addf40c6c9fe27f0a810a7582ea0a (diff) |
Fixed the bprop in PairwiseDistance for pnorms other than one. The gradInput has always been wrong it seems; the sign of the gradient is correct but the magnitude was wrong. I also added a test in extra/nn/test/test.lua for nn.PairwiseDifference, which tests both the batch and non-batch code paths for a few different p-norms.
-rw-r--r-- | PairwiseDistance.lua | 24 | ||||
-rw-r--r-- | test/test.lua | 53 |
2 files changed, 73 insertions, 4 deletions
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua index 1752a88..a11c864 100644 --- a/PairwiseDistance.lua +++ b/PairwiseDistance.lua @@ -19,10 +19,7 @@ function PairwiseDistance:updateOutput(input) local diff = self.diff:zero() --local diff = torch.add(input[1], -1, input[2]) diff:add(input[1], -1, input[2]) - - if math.mod(self.norm, 2) == 1 then - diff:abs() - end + diff:abs() self.output:resize(input[1]:size(1)) self.output:zero() @@ -47,6 +44,25 @@ function PairwiseDistance:updateGradInput(input, gradOutput) self.gradInput[1]:add(-1, input[2]) if self.norm==1 then self.gradInput[1]:apply(mathsign) + else + -- See here for derivative of p-norm: + -- d/dx_k(||x||_p) = (x_k * abs(x_k)^(p-2)) / (||x||_p)^(p-1) + -- http://en.wikipedia.org/wiki/Norm_(mathematics) + self.gradInput[1]:cmul(torch.abs(self.gradInput[1]):pow(self.norm-2)) + if input[1]:dim() == 1 then + -- Avoid the expand for dimension 1 + self.gradInput[1]:mul(math.pow(self.output[1],-(self.norm-1))) + elseif input[1]:dim() == 2 then + -- This is a little messy... But it does work + self.outExpand = self.outExpand or self.output.new() + self.outExpand:resize(self.output:size(1), 1) + self.outExpand:copy(self.output) + self.outExpand:pow(-(self.norm-1)) + self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1), + self.gradInput[1]:size(2))) + else + error('input must be vector or matrix') + end end if input[1]:dim() == 1 then self.gradInput[1]:mul(gradOutput[1]) diff --git a/test/test.lua b/test/test.lua index dd6be22..5a1d469 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1530,6 +1530,59 @@ function nntest.Module_getParameters_7() mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector') end +function nntest.PairwiseDistance() + -- Note: testJacobian doesn't support table inputs, and rather than re-write + -- it so that it does, I'll just use a split table module on the input. + -- I assume both SplitTable and Sequential do not have bugs, otherwise this + -- test will break. + for p = 1,4 do -- test a few Lp norms + -- TEST CASE 1: non-batch inputs + local ini = math.random(10,20) + local input = torch.Tensor(2, ini):zero() + local module = nn.Sequential() + module:add(nn.SplitTable(1)) + module:add(nn.PairwiseDistance(p)) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, ' error on state ') + + local ferr,berr = jac.testIO(module,input) + mytester:asserteq(ferr, 0, torch.typename(module)..' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module)..' - i/o backward err ') + + -- Also check that the forward prop result is correct + err = torch.dist(input:select(1,1), input:select(1,2), p) - + module:forward(input)[1] + mytester:assertlt(err,precision, ' error on non-batch fprop ') + + -- TEST CASE 2: batch input + local inj = math.random(10,20) + input = torch.Tensor(2, inj, ini):zero() + + -- (Rebuild the module to avoid correlated tests) + module = nn.Sequential() + module:add(nn.SplitTable(1)) + module:add(nn.PairwiseDistance(p)) + + err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, ' error on state ') + + -- Also check that the forward prop result is correct + -- manually calculate each distance separately + local inputa = torch.rand(inj,ini) + local inputb = torch.rand(inj,ini) + local dist_manual = torch.Tensor(inj) + for i=1, inputa:size(1) do + dist_manual[i] = torch.dist(inputa:select(1,i), inputb:select(1,i),p) + end + -- compare the distances to the module's fprop + local dist = module:forward(torch.cat(inputa,inputb,1):resize(2,inj,ini)) + err = dist - dist_manual + mytester:assertlt(err:norm(), precision, torch.typename(module) .. + ' error on batch fprop ') + end +end + mytester:add(nntest) if not nn then |