Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r--  PairwiseDistance.lua | 24
-rw-r--r--  test/test.lua | 53
2 files changed, 73 insertions(+), 4 deletions(-)
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua
index 1752a88..a11c864 100644
--- a/PairwiseDistance.lua
+++ b/PairwiseDistance.lua
@@ -19,10 +19,7 @@ function PairwiseDistance:updateOutput(input)
local diff = self.diff:zero()
--local diff = torch.add(input[1], -1, input[2])
diff:add(input[1], -1, input[2])
-
- if math.mod(self.norm, 2) == 1 then
- diff:abs()
- end
+ diff:abs()
self.output:resize(input[1]:size(1))
self.output:zero()
@@ -47,6 +44,25 @@ function PairwiseDistance:updateGradInput(input, gradOutput)
self.gradInput[1]:add(-1, input[2])
if self.norm==1 then
self.gradInput[1]:apply(mathsign)
+ else
+ -- See here for derivative of p-norm:
+ -- d/dx_k(||x||_p) = (x_k * abs(x_k)^(p-2)) / (||x||_p)^(p-1)
+ -- http://en.wikipedia.org/wiki/Norm_(mathematics)
+ self.gradInput[1]:cmul(torch.abs(self.gradInput[1]):pow(self.norm-2))
+ if input[1]:dim() == 1 then
+ -- Avoid the expand for dimension 1
+ self.gradInput[1]:mul(math.pow(self.output[1],-(self.norm-1)))
+ elseif input[1]:dim() == 2 then
+ -- This is a little messy... But it does work
+ self.outExpand = self.outExpand or self.output.new()
+ self.outExpand:resize(self.output:size(1), 1)
+ self.outExpand:copy(self.output)
+ self.outExpand:pow(-(self.norm-1))
+ self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1),
+ self.gradInput[1]:size(2)))
+ else
+ error('input must be vector or matrix')
+ end
end
if input[1]:dim() == 1 then
self.gradInput[1]:mul(gradOutput[1])
diff --git a/test/test.lua b/test/test.lua
index dd6be22..5a1d469 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1530,6 +1530,59 @@ function nntest.Module_getParameters_7()
mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector')
end
+function nntest.PairwiseDistance()
+ -- Note: testJacobian doesn't support table inputs, and rather than re-write
+ -- it so that it does, I'll just use a split table module on the input.
+ -- I assume both SplitTable and Sequential do not have bugs, otherwise this
+ -- test will break.
+ for p = 1,4 do -- test a few Lp norms
+ -- TEST CASE 1: non-batch inputs
+ local ini = math.random(10,20)
+ local input = torch.Tensor(2, ini):zero()
+ local module = nn.Sequential()
+ module:add(nn.SplitTable(1))
+ module:add(nn.PairwiseDistance(p))
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, ' error on state ')
+
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module)..' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module)..' - i/o backward err ')
+
+ -- Also check that the forward prop result is correct
+ err = torch.dist(input:select(1,1), input:select(1,2), p) -
+ module:forward(input)[1]
+ mytester:assertlt(err,precision, ' error on non-batch fprop ')
+
+ -- TEST CASE 2: batch input
+ local inj = math.random(10,20)
+ input = torch.Tensor(2, inj, ini):zero()
+
+ -- (Rebuild the module to avoid correlated tests)
+ module = nn.Sequential()
+ module:add(nn.SplitTable(1))
+ module:add(nn.PairwiseDistance(p))
+
+ err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, ' error on state ')
+
+ -- Also check that the forward prop result is correct
+ -- manually calculate each distance separately
+ local inputa = torch.rand(inj,ini)
+ local inputb = torch.rand(inj,ini)
+ local dist_manual = torch.Tensor(inj)
+ for i=1, inputa:size(1) do
+ dist_manual[i] = torch.dist(inputa:select(1,i), inputb:select(1,i),p)
+ end
+ -- compare the distances to the module's fprop
+ local dist = module:forward(torch.cat(inputa,inputb,1):resize(2,inj,ini))
+ err = dist - dist_manual
+ mytester:assertlt(err:norm(), precision, torch.typename(module) ..
+ ' error on batch fprop ')
+ end
+end
+
mytester:add(nntest)
if not nn then